mirror of https://github.com/microsoft/autogen.git
+mdb atlas vectordb [clean_final] (#3000)
* +mdb atlas
* Update test/agentchat/contrib/vectordb/test_mongodb.py
  Co-authored-by: HRUSHIKESH DOKALA <96101829+Hk669@users.noreply.github.com>
* update test_mongodb.py; we don't need to do the assert .collection_name vs .name
* Try fix mongodb service
* Try fix mongodb service
* Update username and password
* Update autogen/agentchat/contrib/vectordb/mongodb.py
* closer --- but I'm not super thrilled about the solution...
* PYTHON-4506 Expanded tests and simplified vector search pipelines
* Update mongodb.py
* Update mongodb.py - Casey
* search_index_magic index_name change; keeping track of Lucene indexes is tricky
* Fix format
* Fix tests
* hacking, trying to figure this out
* Streamline checks for indexes in construction and restructure tests
* Add tests for score_threshold, embedding inclusion, and multiple query tests
* refactored create_collection to meet base object requirements
* lint
* change the localhost port to 27017
* add test to check that no embedding is there unless explicitly provided
* Update logger
* Add test get docs with ids=None
* Rename and update notebook
* have index management include waiting behaviors
* Adds further optional waits for users and tests. Cleans up upsert.
* ensure the embedding size for multiple embedding inputs is equal to dimensions
* fix up tests and add configuration to ensure documents and indexes are READY for querying
* fix import failure
* adjust typing for 3.9
* fix up the notebook output
* changed language to communicate time taken on first init_chat call
* replace environment variable usage

---------

Co-authored-by: Fabian Valle <fabian.valle-simmons@mongodb.com>
Co-authored-by: HRUSHIKESH DOKALA <96101829+Hk669@users.noreply.github.com>
Co-authored-by: Li Jiang <bnujli@gmail.com>
Co-authored-by: Casey Clements <casey.clements@mongodb.com>
Co-authored-by: Jib <jib.adegunloye@mongodb.com>
Co-authored-by: Jib <Jibzade@gmail.com>
Co-authored-by: Cozypet <yanhan860711@gmail.com>
This commit is contained in:
parent 1bd2124ba4
commit f9295c4c39
@@ -87,6 +87,10 @@ jobs:
           --health-retries 5
         ports:
           - 5432:5432
+      mongodb:
+        image: mongodb/mongodb-atlas-local:latest
+        ports:
+          - 27017:27017
     steps:
       - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
@@ -104,6 +108,9 @@ jobs:
       - name: Install pgvector when on linux
         run: |
           pip install -e .[retrievechat-pgvector]
+      - name: Install mongodb when on linux
+        run: |
+          pip install -e .[retrievechat-mongodb]
       - name: Install unstructured when python-version is 3.9 and on linux
         if: matrix.python-version == '3.9'
         run: |
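For local debugging, the same service can be exercised directly. A minimal connectivity check (a sketch; the URI mirrors the default that the tests below fall back to):

```python
from pymongo import MongoClient

# Assumes the mongodb/mongodb-atlas-local container from the workflow above is
# running on localhost; directConnection targets its single-node deployment.
client = MongoClient("mongodb://localhost:27017/?directConnection=true")
client.admin.command("ping")  # raises ServerSelectionTimeoutError if unreachable
```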
@@ -186,7 +186,8 @@ class VectorDB(Protocol):
             ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
             collection_name: str | The name of the collection. Default is None.
             include: List[str] | The fields to include. Default is None.
-                If None, will include ["metadatas", "documents"], ids will always be included.
+                If None, will include ["metadatas", "documents"], ids will always be included. This may differ
+                depending on the implementation.
             kwargs: dict | Additional keyword arguments.

         Returns:
@@ -200,7 +201,7 @@ class VectorDBFactory:
     Factory class for creating vector databases.
     """

-    PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "qdrant"]
+    PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "mongodb", "qdrant"]

     @staticmethod
     def create_vector_db(db_type: str, **kwargs) -> VectorDB:
@@ -222,6 +223,10 @@ class VectorDBFactory:
             from .pgvectordb import PGVectorDB

             return PGVectorDB(**kwargs)
+        if db_type.lower() in ["mdb", "mongodb", "atlas"]:
+            from .mongodb import MongoDBAtlasVectorDB
+
+            return MongoDBAtlasVectorDB(**kwargs)
         if db_type.lower() in ["qdrant", "qdrantdb"]:
             from .qdrant import QdrantVectorDB
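With this change, the factory resolves any of the aliases "mdb", "mongodb", or "atlas" to the new backend. A minimal usage sketch (the connection string is a placeholder):

```python
from autogen.agentchat.contrib.vectordb.base import VectorDBFactory

# Any of "mdb", "mongodb", or "atlas" selects MongoDBAtlasVectorDB;
# the remaining kwargs are forwarded to its constructor.
db = VectorDBFactory.create_vector_db(
    db_type="mongodb",
    connection_string="mongodb+srv://<user>:<password>@<cluster>/",  # placeholder
    database_name="vector_db",
)
```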
@@ -0,0 +1,553 @@
from copy import deepcopy
from time import monotonic, sleep
from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Set, Tuple, Union

import numpy as np
from pymongo import MongoClient, UpdateOne, errors
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.operations import SearchIndexModel
from sentence_transformers import SentenceTransformer

from .base import Document, ItemID, QueryResults, VectorDB
from .utils import get_logger

logger = get_logger(__name__)

DEFAULT_INSERT_BATCH_SIZE = 100_000
_SAMPLE_SENTENCE = ["The weather is lovely today in paradise."]
_DELAY = 0.5


def with_id_rename(docs: Iterable) -> List[Dict[str, Any]]:
    """Utility that renames the _id field from a Collection to id for a Document."""
    return [{**{k: v for k, v in d.items() if k != "_id"}, "id": d["_id"]} for d in docs]


class MongoDBAtlasVectorDB(VectorDB):
    """
    A VectorDB implementation backed by a MongoDB Atlas Collection.
    """

    def __init__(
        self,
        connection_string: str = "",
        database_name: str = "vector_db",
        embedding_function: Callable = SentenceTransformer("all-MiniLM-L6-v2").encode,
        collection_name: str = None,
        index_name: str = "vector_index",
        overwrite: bool = False,
        wait_until_index_ready: float = None,
        wait_until_document_ready: float = None,
    ):
        """
        Initialize the vector database.

        Args:
            connection_string: str | The MongoDB connection string to connect to. Default is ''.
            database_name: str | The name of the database. Default is 'vector_db'.
            embedding_function: Callable | The embedding function used to generate the vector representation.
            collection_name: str | The name of the collection to create for this vector database.
                Defaults to None.
            index_name: str | Index name for the vector database. Defaults to 'vector_index'.
            overwrite: bool | Whether to overwrite an existing collection. Default is False.
            wait_until_index_ready: float | None | Blocking call to wait until the
                database indexes are ready. None, the default, means no wait.
            wait_until_document_ready: float | None | Blocking call to wait until inserted
                documents are queryable. None, the default, means no wait.
        """
        self.embedding_function = embedding_function
        self.index_name = index_name
        self._wait_until_index_ready = wait_until_index_ready
        self._wait_until_document_ready = wait_until_document_ready

        # This will get the model dimension size by computing the embeddings dimensions
        self.dimensions = self._get_embedding_size()

        try:
            self.client = MongoClient(connection_string, driver=DriverInfo(name="autogen"))
            self.client.admin.command("ping")
            logger.debug("Successfully created MongoClient")
        except errors.ServerSelectionTimeoutError as err:
            raise ConnectionError("Could not connect to MongoDB server") from err

        self.db = self.client[database_name]
        logger.debug(f"Atlas Database name: {self.db.name}")
        if collection_name:
            self.active_collection = self.create_collection(collection_name, overwrite)
        else:
            self.active_collection = None

    def _is_index_ready(self, collection: Collection, index_name: str):
        """Check the list of available search indexes to see if the
        specified index has status READY.

        Args:
            collection (Collection): MongoDB Collection to check for the search indexes
            index_name (str): Vector Search Index name

        Returns:
            bool : True if the index is present and READY, False otherwise
        """
        for index in collection.list_search_indexes(index_name):
            if index["type"] == "vectorSearch" and index["status"] == "READY":
                return True
        return False

    def _wait_for_index(self, collection: Collection, index_name: str, action: str = "create"):
        """Waits for the index action to be completed. Otherwise raises a TimeoutError.

        Timeout set on instantiation.
        action: "create" or "delete"
        """
        assert action in ["create", "delete"], f"{action=} must be create or delete."
        start = monotonic()
        while monotonic() - start < self._wait_until_index_ready:
            if action == "create" and self._is_index_ready(collection, index_name):
                return
            elif action == "delete" and len(list(collection.list_search_indexes())) == 0:
                return
            sleep(_DELAY)

        raise TimeoutError(f"Index {self.index_name} is not ready!")

    def _wait_for_document(self, collection: Collection, index_name: str, doc: Document):
        start = monotonic()
        while monotonic() - start < self._wait_until_document_ready:
            query_result = _vector_search(
                embedding_vector=np.array(self.embedding_function(doc["content"])).tolist(),
                n_results=1,
                collection=collection,
                index_name=index_name,
            )
            if query_result and query_result[0][0]["_id"] == doc["id"]:
                return
            sleep(_DELAY)

        raise TimeoutError(f"Document with id {doc['id']} is not ready!")

    def _get_embedding_size(self):
        return len(self.embedding_function(_SAMPLE_SENTENCE)[0])

    def list_collections(self):
        """
        List the collections in the vector database.

        Returns:
            List[str] | The list of collections.
        """
        return self.db.list_collection_names()

    def create_collection(
        self,
        collection_name: str,
        overwrite: bool = False,
        get_or_create: bool = True,
    ) -> Collection:
        """
        Create a collection in the vector database and create a vector search index in the collection.

        Args:
            collection_name: str | The name of the collection.
            overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
            get_or_create: bool | Whether to get or create the collection. Default is True.
        """
        if overwrite:
            self.delete_collection(collection_name)

        if collection_name not in self.db.list_collection_names():
            # Create a new collection
            coll = self.db.create_collection(collection_name)
            self.create_index_if_not_exists(index_name=self.index_name, collection=coll)
            return coll

        if get_or_create:
            # The collection already exists, return it.
            coll = self.db[collection_name]
            self.create_index_if_not_exists(index_name=self.index_name, collection=coll)
            return coll
        else:
            # get_or_create is False and the collection already exists, raise an error.
            raise ValueError(f"Collection {collection_name} already exists.")

    def create_index_if_not_exists(self, index_name: str = "vector_index", collection: Collection = None) -> None:
        """
        Creates a vector search index on the specified collection in MongoDB.

        Args:
            index_name (str, optional): The name of the vector search index to create. Defaults to "vector_index".
            collection (Collection, optional): The MongoDB collection to create the index on. Defaults to None.
        """
        if not self._is_index_ready(collection, index_name):
            self.create_vector_search_index(collection, index_name)

    def get_collection(self, collection_name: str = None) -> Collection:
        """
        Get the collection from the vector database.

        Args:
            collection_name: str | The name of the collection. Default is None. If None, return the
                current active collection.

        Returns:
            Collection | The collection object.
        """
        if collection_name is None:
            if self.active_collection is None:
                raise ValueError("No collection is specified.")
            else:
                logger.debug(
                    f"No collection is specified. Using current active collection {self.active_collection.name}."
                )
        else:
            self.active_collection = self.db[collection_name]

        return self.active_collection

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.
        """
        for index in self.db[collection_name].list_search_indexes():
            self.db[collection_name].drop_search_index(index["name"])
            if self._wait_until_index_ready:
                self._wait_for_index(self.db[collection_name], index["name"], "delete")
        return self.db[collection_name].drop()

    def create_vector_search_index(
        self,
        collection: Collection,
        index_name: Union[str, None] = "vector_index",
        similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine",
    ) -> None:
        """Create a vector search index in the collection.

        Args:
            collection: An existing Collection in the Atlas Database.
            index_name: Vector Search Index name.
            similarity: Algorithm used for measuring vector similarity.

        Returns:
            None
        """
        search_index_model = SearchIndexModel(
            definition={
                "fields": [
                    {
                        "type": "vector",
                        "numDimensions": self.dimensions,
                        "path": "embedding",
                        "similarity": similarity,
                    },
                ]
            },
            name=index_name,
            type="vectorSearch",
        )
        # Create the search index
        try:
            collection.create_search_index(model=search_index_model)
            if self._wait_until_index_ready:
                self._wait_for_index(collection, index_name, "create")
            logger.debug(f"Search index {index_name} created successfully.")
        except Exception as e:
            logger.error(
                f"Error creating search index: {e}. \n"
                f"Your client must be connected to an Atlas cluster. "
                f"You may have to manually create a Collection and Search Index "
                f"if you are on a free/shared cluster."
            )
            raise e

    def insert_docs(
        self,
        docs: List[Document],
        collection_name: str = None,
        upsert: bool = False,
        batch_size=DEFAULT_INSERT_BATCH_SIZE,
        **kwargs,
    ) -> None:
        """Insert Documents and Vector Embeddings into the collection of the vector database.

        For large numbers of Documents, insertion is performed in batches.

        Args:
            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
            collection_name: str | The name of the collection. Default is None.
            upsert: bool | Whether to update the document if it exists. Default is False.
            batch_size: Number of documents to be inserted in each batch.
        """
        if not docs:
            logger.info("No documents to insert.")
            return

        collection = self.get_collection(collection_name)
        if upsert:
            self.update_docs(docs, collection.name, upsert=True)
        else:
            # Sanity checking the first document
            if docs[0].get("content") is None:
                raise ValueError("The document content is required.")
            if docs[0].get("id") is None:
                raise ValueError("The document id is required.")

            input_ids = set()
            result_ids = set()
            id_batch = []
            text_batch = []
            metadata_batch = []
            size = 0
            i = 0
            for doc in docs:
                id = doc["id"]
                text = doc["content"]
                metadata = doc.get("metadata", {})
                id_batch.append(id)
                text_batch.append(text)
                metadata_batch.append(metadata)
                id_size = 1 if isinstance(id, int) else len(id)
                size += len(text) + len(metadata) + id_size
                if (i + 1) % batch_size == 0 or size >= 47_000_000:
                    result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
                    input_ids.update(id_batch)
                    id_batch = []
                    text_batch = []
                    metadata_batch = []
                    size = 0
                i += 1
            if text_batch:
                result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))  # type: ignore
                input_ids.update(id_batch)

            if result_ids != input_ids:
                logger.warning(
                    "Possible data corruption. "
                    "input_ids not in result_ids: {in_diff}.\n"
                    "result_ids not in input_ids: {out_diff}".format(
                        in_diff=input_ids.difference(result_ids), out_diff=result_ids.difference(input_ids)
                    )
                )
        if self._wait_until_document_ready and docs:
            self._wait_for_document(collection, self.index_name, docs[-1])

    def _insert_batch(
        self, collection: Collection, texts: List[str], metadatas: List[Mapping[str, Any]], ids: List[ItemID]
    ) -> Set[ItemID]:
        """Compute embeddings for and insert a batch of Documents into the Collection.

        For performance reasons, we chose to call self.embedding_function just once,
        with the hopefully small tradeoff of having to recreate the Document dicts.

        Args:
            collection: MongoDB Collection
            texts: List of the main contents of each document
            metadatas: List of metadata mappings
            ids: List of ids. Note that these are stored as _id in the Collection.

        Returns:
            List of ids inserted.
        """
        n_texts = len(texts)
        if n_texts == 0:
            return []
        # Embed and create the documents
        embeddings = self.embedding_function(texts).tolist()
        assert (
            len(embeddings) == n_texts
        ), f"The number of embeddings produced by self.embedding_function ({len(embeddings)}) does not match the number of texts provided to it ({n_texts})."
        to_insert = [
            {"_id": i, "content": t, "metadata": m, "embedding": e}
            for i, t, m, e in zip(ids, texts, metadatas, embeddings)
        ]
        # insert the documents in MongoDB Atlas
        insert_result = collection.insert_many(to_insert)  # type: ignore
        return insert_result.inserted_ids  # TODO Remove this. Replace by log like update_docs

    def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None:
        """Update documents, including their embeddings, in the Collection.

        Optionally allow upsert as kwarg.

        Uses deepcopy to avoid changing docs.

        Args:
            docs: List[Document] | A list of documents.
            collection_name: str | The name of the collection. Default is None.
            kwargs: Any | Use `upsert=True` to insert documents whose ids are not present in the collection.
        """

        n_docs = len(docs)
        logger.info(f"Preparing to embed and update {n_docs=}")
        # Compute the embeddings
        embeddings: list[list[float]] = self.embedding_function([doc["content"] for doc in docs]).tolist()
        # Prepare the updates
        all_updates = []
        for i in range(n_docs):
            doc = deepcopy(docs[i])
            doc["embedding"] = embeddings[i]
            doc["_id"] = doc.pop("id")

            all_updates.append(UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=kwargs.get("upsert", False)))
        # Perform update in bulk
        collection = self.get_collection(collection_name)
        result = collection.bulk_write(all_updates)

        if self._wait_until_document_ready and docs:
            self._wait_for_document(collection, self.index_name, docs[-1])

        # Log a result summary
        logger.info(
            "Matched: %s, Modified: %s, Upserted: %s",
            result.matched_count,
            result.modified_count,
            result.upserted_count,
        )

    def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs):
        """
        Delete documents from the collection of the vector database.

        Args:
            ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
            collection_name: str | The name of the collection. Default is None.
        """
        collection = self.get_collection(collection_name)
        return collection.delete_many({"_id": {"$in": ids}})

    def get_docs_by_ids(
        self, ids: List[ItemID] = None, collection_name: str = None, include: List[str] = None, **kwargs
    ) -> List[Document]:
        """
        Retrieve documents from the collection of the vector database based on the ids.

        Args:
            ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
            collection_name: str | The name of the collection. Default is None.
            include: List[str] | The fields to include.
                If None, will include ["metadata", "content"]; ids will always be included.
                Basically, use include to choose whether to include embedding and metadata.
            kwargs: dict | Additional keyword arguments.

        Returns:
            List[Document] | The results.
        """
        if include is None:
            include_fields = {"_id": 1, "content": 1, "metadata": 1}
        else:
            include_fields = {k: 1 for k in set(include).union({"_id"})}
        collection = self.get_collection(collection_name)
        if ids is not None:
            docs = collection.find({"_id": {"$in": ids}}, include_fields)
        else:
            docs = collection.find({}, include_fields)
        # Rename the _id field from the Collection to id for each Document
        return with_id_rename(docs)

    def retrieve_docs(
        self,
        queries: List[str],
        collection_name: str = None,
        n_results: int = 10,
        distance_threshold: float = -1,
        **kwargs,
    ) -> QueryResults:
        """
        Retrieve documents from the collection of the vector database based on the queries.

        Args:
            queries: List[str] | A list of queries. Each query is a string.
            collection_name: str | The name of the collection. Default is None.
            n_results: int | The number of relevant documents to return. Default is 10.
            distance_threshold: float | The threshold for the distance score; only distances smaller than it will be
                returned. Don't filter with it if < 0. Default is -1.
            kwargs: Dict | Additional keyword arguments. The ones of importance follow:
                oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm.
                    It determines the number of nearest neighbor candidates to consider during the search phase.
                    A higher value leads to more accuracy, but is slower. Default is 10.

        Returns:
            QueryResults | For each query string, a list of nearest documents and their scores.
        """
        collection = self.get_collection(collection_name)
        # Trivial case of an empty collection
        if collection.count_documents({}) == 0:
            return []

        logger.debug(f"Using index: {self.index_name}")
        # Pop oversampling_factor so it is not passed a second time through **kwargs below.
        oversampling_factor = kwargs.pop("oversampling_factor", 10)
        results = []
        for query_text in queries:
            # Compute embedding vector from semantic query
            logger.debug(f"Query: {query_text}")
            query_vector = np.array(self.embedding_function([query_text])).tolist()[0]
            # Find documents with similar vectors using the specified index
            query_result = _vector_search(
                query_vector,
                n_results,
                collection,
                self.index_name,
                distance_threshold,
                oversampling_factor=oversampling_factor,
                **kwargs,
            )
            # Change each _id key to id. Like with_id_rename, but with (doc, score) tuples
            results.append(
                [({**{k: v for k, v in d[0].items() if k != "_id"}, "id": d[0]["_id"]}, d[1]) for d in query_result]
            )
        return results


def _vector_search(
    embedding_vector: List[float],
    n_results: int,
    collection: Collection,
    index_name: str,
    distance_threshold: float = -1.0,
    oversampling_factor=10,
    include_embedding=False,
) -> List[Tuple[Dict, float]]:
    """Core $vectorSearch aggregation pipeline.

    Args:
        embedding_vector: Embedding vector of the semantic query
        n_results: Number of documents to return
        collection: MongoDB Collection with vector index
        index_name: Name of the vector index
        distance_threshold: Only distance measures smaller than this will be returned.
            Don't filter with it if < 0. Default is -1.
        oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm.
            It determines the number of nearest neighbor candidates to consider during the search phase.
            A higher value leads to more accuracy, but is slower. Default = 10
        include_embedding: bool | Whether to keep the embedding vector in the returned documents. Default = False.

    Returns:
        List of tuples of length n_results from Collection.
        Each tuple contains a document dict and a score.
    """

    pipeline = [
        {
            "$vectorSearch": {
                "index": index_name,
                "limit": n_results,
                "numCandidates": n_results * oversampling_factor,
                "queryVector": embedding_vector,
                "path": "embedding",
            }
        },
        {"$set": {"score": {"$meta": "vectorSearchScore"}}},
    ]
    if distance_threshold >= 0.0:
        similarity_threshold = 1.0 - distance_threshold
        pipeline.append({"$match": {"score": {"$gte": similarity_threshold}}})

    if not include_embedding:
        pipeline.append({"$project": {"embedding": 0}})

    logger.debug("pipeline: %s", pipeline)
    agg = collection.aggregate(pipeline)
    return [(doc, doc.pop("score")) for doc in agg]
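Taken together, a typical round trip through this class looks roughly like the sketch below (assuming a reachable Atlas cluster; the connection string and wait timeouts are illustrative):

```python
from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB

db = MongoDBAtlasVectorDB(
    connection_string="mongodb+srv://<user>:<password>@<cluster>/",  # placeholder
    database_name="vector_db",
    collection_name="demo_collection",
    wait_until_index_ready=120.0,  # block until the search index is READY
    wait_until_document_ready=120.0,  # block until inserts are queryable
)
db.insert_docs(
    [
        {"id": "1", "content": "The weather is lovely today."},
        {"id": "2", "content": "Vector search ranks documents by similarity."},
    ]
)
# Each query returns a list of (document, score) tuples.
results = db.retrieve_docs(queries=["How does vector search rank results?"], n_results=1)
```

Note that scores come from `$meta: "vectorSearchScore"`, so larger means more similar; a `distance_threshold` is converted to `1.0 - distance` before the `$match` stage.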
@@ -0,0 +1,591 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using RetrieveChat Powered by MongoDB Atlas for Retrieve Augmented Code Generation and Question Answering\n",
    "\n",
    "AutoGen offers conversable agents powered by LLM, tool or human, which can be used to perform tasks collectively via automated chat. This framework allows tool use and human participation through multi-agent conversation.\n",
    "Please find documentation about this feature [here](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat).\n",
    "\n",
    "RetrieveChat is a conversational system for retrieval-augmented code generation and question answering. In this notebook, we demonstrate how to utilize RetrieveChat to generate code and answer questions based on customized documentations that are not present in the LLM's training dataset. RetrieveChat uses the `RetrieveAssistantAgent` and `RetrieveUserProxyAgent`, which is similar to the usage of `AssistantAgent` and `UserProxyAgent` in other notebooks (e.g., [Automated Task Solving with Code Generation, Execution & Debugging](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_auto_feedback_from_code_execution.ipynb)). Essentially, `RetrieveAssistantAgent` and `RetrieveUserProxyAgent` implement a different auto-reply mechanism corresponding to the RetrieveChat prompts.\n",
    "\n",
    "## Table of Contents\n",
    "We'll demonstrate one example of using RetrieveChat for code generation and question answering:\n",
    "\n",
    "- [Example 1: Generate code based off docstrings w/o human feedback](#example-1)\n",
    "\n",
    "````{=mdx}\n",
    ":::info Requirements\n",
    "Some extra dependencies are needed for this notebook, which can be installed via pip:\n",
    "\n",
    "```bash\n",
    "pip install pyautogen[retrievechat-mongodb] flaml[automl]\n",
    "```\n",
    "\n",
    "For more information, please refer to the [installation guide](/docs/installation/).\n",
    ":::\n",
    "````\n",
    "\n",
    "Ensure you have a MongoDB Atlas instance with Cluster Tier >= M30."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set your API Endpoint\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "models to use: ['gpt-3.5-turbo-0125']\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "\n",
    "import autogen\n",
    "from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent\n",
    "from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent\n",
    "\n",
    "# Accepted file formats that can be stored in\n",
    "# a vector database instance\n",
    "from autogen.retrieve_utils import TEXT_FORMATS\n",
    "\n",
    "config_list = [{\"model\": \"gpt-3.5-turbo-0125\", \"api_key\": os.environ[\"OPENAI_API_KEY\"], \"api_type\": \"openai\"}]\n",
    "assert len(config_list) > 0\n",
    "print(\"models to use: \", [config_list[i][\"model\"] for i in range(len(config_list))])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "````{=mdx}\n",
    ":::tip\n",
    "Learn more about configuring LLMs for agents [here](/docs/topics/llm_configuration).\n",
    ":::\n",
    "````\n",
    "\n",
    "## Construct agents for RetrieveChat\n",
    "\n",
    "We start by initializing the `RetrieveAssistantAgent` and `RetrieveUserProxyAgent`. The system message needs to be set to \"You are a helpful assistant.\" for RetrieveAssistantAgent. The detailed instructions are given in the user message. Later we will use the `RetrieveUserProxyAgent.message_generator` to combine the instructions and a retrieval augmented generation task for an initial prompt to be sent to the LLM assistant."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accepted file formats for `docs_path`:\n",
      "['txt', 'json', 'csv', 'tsv', 'md', 'html', 'htm', 'rtf', 'rst', 'jsonl', 'log', 'xml', 'yaml', 'yml', 'pdf']\n"
     ]
    }
   ],
   "source": [
    "print(\"Accepted file formats for `docs_path`:\")\n",
    "print(TEXT_FORMATS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. create a RetrieveAssistantAgent instance named \"assistant\"\n",
    "assistant = RetrieveAssistantAgent(\n",
    "    name=\"assistant\",\n",
    "    system_message=\"You are a helpful assistant.\",\n",
    "    llm_config={\n",
    "        \"timeout\": 600,\n",
    "        \"cache_seed\": 42,\n",
    "        \"config_list\": config_list,\n",
    "    },\n",
    ")\n",
    "\n",
    "# 2. create the RetrieveUserProxyAgent instance named \"ragproxyagent\"\n",
    "# By default, the human_input_mode is \"ALWAYS\", which means the agent will ask for human input at every step. We set it to \"NEVER\" here.\n",
    "# `docs_path` is the path to the docs directory. It can also be the path to a single file, or the url to a single file. By default,\n",
    "# it is set to None, which works only if the collection is already created.\n",
    "# `task` indicates the kind of task we're working on. In this example, it's a `code` task.\n",
    "# `chunk_token_size` is the chunk token size for the retrieve chat. By default, it is set to `max_tokens * 0.6`, here we set it to 2000.\n",
    "# `custom_text_types` is a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.\n",
    "# This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types.\n",
    "# In this example, we set it to [\"non-existent-type\"] so that no local file types match. Since no \"non-existent-type\" files are included in the `website/docs`,\n",
    "# no files there will be processed. However, the explicitly included urls will still be processed.\n",
    "# **NOTE** Upon the first time adding in the documents, initial query may be slower due to index creation and document indexing time\n",
    "ragproxyagent = RetrieveUserProxyAgent(\n",
    "    name=\"ragproxyagent\",\n",
    "    human_input_mode=\"NEVER\",\n",
    "    max_consecutive_auto_reply=3,\n",
    "    retrieve_config={\n",
    "        \"task\": \"code\",\n",
    "        \"docs_path\": [\n",
    "            \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Examples/Integrate%20-%20Spark.md\",\n",
    "            \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Research.md\",\n",
    "            os.path.join(os.path.abspath(\"\"), \"..\", \"website\", \"docs\"),\n",
    "        ],\n",
    "        \"custom_text_types\": [\"non-existent-type\"],\n",
    "        \"chunk_token_size\": 2000,\n",
    "        \"model\": config_list[0][\"model\"],\n",
    "        \"vector_db\": \"mongodb\",  # MongoDB Atlas database\n",
    "        \"collection_name\": \"demo_collection\",\n",
    "        \"db_config\": {\n",
    "            \"connection_string\": os.environ[\"MONGODB_URI\"],  # MongoDB Atlas connection string\n",
    "            \"database_name\": \"test_db\",  # MongoDB Atlas database\n",
    "            \"index_name\": \"vector_index\",\n",
    "            \"wait_until_index_ready\": 120.0,  # Setting to wait 120 seconds or until index is constructed before querying\n",
    "            \"wait_until_document_ready\": 120.0,  # Setting to wait 120 seconds or until document is properly indexed after insertion/update\n",
    "        },\n",
    "        \"get_or_create\": True,  # set to False if you don't want to reuse an existing collection\n",
    "        \"overwrite\": False,  # set to True if you want to overwrite an existing collection; each overwrite will force an index creation and re-upload of documents\n",
    "    },\n",
    "    code_execution_config=False,  # set to False if you don't want to execute the code\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example 1\n",
    "\n",
    "[Back to top](#table-of-contents)\n",
    "\n",
    "Use RetrieveChat to help generate sample code and automatically run the code and fix errors if there are any.\n",
    "\n",
    "Problem: Which API should I use if I want to use FLAML for a classification task and I want to train the model in 30 seconds. Use spark to parallel the training. Force cancel jobs if time limit is reached."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-07-25 13:47:30,700 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - \u001b[32mUse the existing collection `demo_collection`.\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Trying to create collection.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-07-25 13:47:31,048 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - Found 2 chunks.\u001b[0m\n",
      "2024-07-25 13:47:31,051 - autogen.agentchat.contrib.vectordb.mongodb - INFO - No documents to insert.\u001b[0m\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n",
      "\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n",
      "\u001b[32mAdding content of doc 7968cf3c to context.\u001b[0m\n",
      "\u001b[33mragproxyagent\u001b[0m (to assistant):\n",
      "\n",
      "You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n",
      "context provided by the user.\n",
      "If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n",
      "For code generation, you must obey the following rules:\n",
      "Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n",
      "Rule 2. You must follow the formats below to write your code:\n",
      "```language\n",
      "# your code\n",
      "```\n",
      "\n",
      "User's question is: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n",
      "\n",
      "Context is: # Integrate - Spark\n",
      "\n",
      "FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n",
      "\n",
      "- Use Spark ML estimators for AutoML.\n",
      "- Use Spark to run training in parallel spark jobs.\n",
      "\n",
      "## Spark ML Estimators\n",
      "\n",
      "FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n",
      "\n",
      "### Data\n",
      "\n",
      "For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n",
      "\n",
      "This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n",
      "\n",
      "This function also accepts optional arguments `index_col` and `default_index_type`.\n",
      "\n",
      "- `index_col` is the column name to use as the index, default is None.\n",
      "- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n",
      "\n",
      "Here is an example code snippet for Spark Data:\n",
      "\n",
      "```python\n",
      "import pandas as pd\n",
      "from flaml.automl.spark.utils import to_pandas_on_spark\n",
      "\n",
      "# Creating a dictionary\n",
      "data = {\n",
      "    \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n",
      "    \"Age_Years\": [20, 15, 10, 7, 25],\n",
      "    \"Price\": [100000, 200000, 300000, 240000, 120000],\n",
      "}\n",
      "\n",
      "# Creating a pandas DataFrame\n",
      "dataframe = pd.DataFrame(data)\n",
      "label = \"Price\"\n",
      "\n",
      "# Convert to pandas-on-spark dataframe\n",
      "psdf = to_pandas_on_spark(dataframe)\n",
      "```\n",
      "\n",
      "To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n",
      "\n",
      "Here is an example of how to use it:\n",
      "\n",
      "```python\n",
      "from pyspark.ml.feature import VectorAssembler\n",
      "\n",
      "columns = psdf.columns\n",
      "feature_cols = [col for col in columns if col != label]\n",
      "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
      "psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n",
      "```\n",
      "\n",
      "Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n",
      "\n",
      "### Estimators\n",
      "\n",
      "#### Model List\n",
      "\n",
      "- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n",
      "\n",
      "#### Usage\n",
      "\n",
      "First, prepare your data in the required format as described in the previous section.\n",
      "\n",
      "By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n",
      "\n",
      "Here is an example code snippet using SparkML models in AutoML:\n",
      "\n",
      "```python\n",
      "import flaml\n",
      "\n",
      "# prepare your data in pandas-on-spark format as we previously mentioned\n",
      "\n",
      "automl = flaml.AutoML()\n",
      "settings = {\n",
      "    \"time_budget\": 30,\n",
      "    \"metric\": \"r2\",\n",
      "    \"estimator_list\": [\"lgbm_spark\"],  # this setting is optional\n",
      "    \"task\": \"regression\",\n",
      "}\n",
      "\n",
      "automl.fit(\n",
      "    dataframe=psdf,\n",
      "    label=label,\n",
      "    **settings,\n",
      ")\n",
      "```\n",
      "\n",
      "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n",
      "\n",
      "## Parallel Spark Jobs\n",
      "\n",
      "You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n",
      "\n",
      "Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n",
      "\n",
      "All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n",
      "\n",
      "- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n",
      "- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n",
      "- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n",
      "\n",
      "An example code snippet for using parallel Spark jobs:\n",
      "\n",
      "```python\n",
      "import flaml\n",
      "\n",
      "automl_experiment = flaml.AutoML()\n",
      "automl_settings = {\n",
      "    \"time_budget\": 30,\n",
      "    \"metric\": \"r2\",\n",
      "    \"task\": \"regression\",\n",
      "    \"n_concurrent_trials\": 2,\n",
      "    \"use_spark\": True,\n",
      "    \"force_cancel\": True,  # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n",
      "}\n",
      "\n",
      "automl.fit(\n",
      "    dataframe=dataframe,\n",
      "    label=label,\n",
      "    **automl_settings,\n",
      ")\n",
      "```\n",
      "\n",
      "[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n",
      "# Research\n",
      "\n",
      "For technical details, please check our research publications.\n",
      "\n",
      "- [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n",
      "\n",
      "```bibtex\n",
      "@inproceedings{wang2021flaml,\n",
      "    title={FLAML: A Fast and Lightweight AutoML Library},\n",
      "    author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\n",
      "    year={2021},\n",
      "    booktitle={MLSys},\n",
      "}\n",
      "```\n",
      "\n",
      "- [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n",
      "\n",
      "```bibtex\n",
      "@inproceedings{wu2021cfo,\n",
      "    title={Frugal Optimization for Cost-related Hyperparameters},\n",
      "    author={Qingyun Wu and Chi Wang and Silu Huang},\n",
      "    year={2021},\n",
      "    booktitle={AAAI},\n",
      "}\n",
      "```\n",
      "\n",
      "- [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n",
      "\n",
      "```bibtex\n",
      "@inproceedings{wang2021blendsearch,\n",
      "    title={Economical Hyperparameter Optimization With Blended Search Strategy},\n",
      "    author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\n",
      "    year={2021},\n",
      "    booktitle={ICLR},\n",
      "}\n",
      "```\n",
      "\n",
      "- [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n",
      "\n",
      "```bibtex\n",
      "@inproceedings{liuwang2021hpolm,\n",
      "    title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\n",
      "    author={Susan Xueqing Liu and Chi Wang},\n",
      "    year={2021},\n",
      "    booktitle={ACL},\n",
      "}\n",
      "```\n",
      "\n",
      "- [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n",
      "\n",
      "```bibtex\n",
      "@inproceedings{wu2021chacha,\n",
      "    title={ChaCha for Online AutoML},\n",
      "    author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\n",
      "    year={2021},\n",
      "    booktitle={ICML},\n",
      "}\n",
      "```\n",
      "\n",
      "- [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n",
      "\n",
      "```bibtex\n",
      "@inproceedings{wuwang2021fairautoml,\n",
      "    title={Fair AutoML},\n",
      "    author={Qingyun Wu and Chi Wang},\n",
      "    year={2021},\n",
      "    booktitle={ArXiv preprint arXiv:2111.06495},\n",
      "}\n",
      "```\n",
      "\n",
      "- [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n",
      "\n",
      "```bibtex\n",
      "@inproceedings{kayaliwang2022default,\n",
      "    title={Mining Robust Default Configurations for Resource-constrained AutoML},\n",
      "    author={Moe Kayali and Chi Wang},\n",
      "    year={2022},\n",
      "    booktitle={ArXiv preprint arXiv:2202.09927},\n",
      "}\n",
      "```\n",
      "\n",
      "- [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n",
      "\n",
      "```bibtex\n",
      "@inproceedings{zhang2023targeted,\n",
      "    title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\n",
      "    author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\n",
      "    booktitle={International Conference on Learning Representations},\n",
      "    year={2023},\n",
      "    url={https://openreview.net/forum?id=0Ij9_q567Ma},\n",
      "}\n",
      "```\n",
      "\n",
      "- [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n",
      "\n",
      "```bibtex\n",
      "@inproceedings{wang2023EcoOptiGen,\n",
      "    title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\n",
      "    author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\n",
      "    year={2023},\n",
      "    booktitle={ArXiv preprint arXiv:2303.04673},\n",
      "}\n",
      "```\n",
      "\n",
      "- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n",
      "\n",
      "```bibtex\n",
      "@inproceedings{wu2023empirical,\n",
      "    title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\n",
      "    author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\n",
      "    year={2023},\n",
      "    booktitle={ArXiv preprint arXiv:2306.01337},\n",
      "}\n",
      "```\n",
      "\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33massistant\u001b[0m (to ragproxyagent):\n",
      "\n",
      "To use FLAML to perform a classification task and use Spark for parallel training with a timeout of 30 seconds and force canceling jobs if the time limit is reached, you can follow the below code snippet:\n",
      "\n",
      "```python\n",
      "import flaml\n",
      "from flaml.automl.spark.utils import to_pandas_on_spark\n",
      "from pyspark.ml.feature import VectorAssembler\n",
      "\n",
      "# Prepare your data in pandas-on-spark format\n",
      "data = {\n",
      "    \"feature1\": [val1, val2, val3, val4],\n",
      "    \"feature2\": [val5, val6, val7, val8],\n",
      "    \"target\": [class1, class2, class1, class2],\n",
      "}\n",
      "\n",
      "dataframe = pd.DataFrame(data)\n",
      "label = \"target\"\n",
      "psdf = to_pandas_on_spark(dataframe)\n",
      "\n",
      "# Prepare your features using VectorAssembler\n",
      "columns = psdf.columns\n",
      "feature_cols = [col for col in columns if col != label]\n",
      "featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
      "psdf = featurizer.transform(psdf)\n",
      "\n",
      "# Define AutoML settings and fit the model\n",
      "automl = flaml.AutoML()\n",
      "settings = {\n",
      "    \"time_budget\": 30,\n",
      "    \"metric\": \"accuracy\",\n",
      "    \"task\": \"classification\",\n",
      "    \"estimator_list\": [\"lgbm_spark\"],  # Optional\n",
      "}\n",
      "\n",
      "automl.fit(\n",
      "    dataframe=psdf,\n",
      "    label=label,\n",
      "    **settings,\n",
      ")\n",
      "```\n",
      "\n",
      "In the code:\n",
      "- Replace `val1, val2, ..., class1, class2` with your actual data values.\n",
      "- Ensure the features and target columns are correctly specified in the data dictionary.\n",
      "- Set the `time_budget` parameter to 30 to limit the training time.\n",
      "- The `force_cancel` parameter is set to `True` to force cancel Spark jobs if the time limit is exceeded.\n",
      "\n",
      "Make sure to adapt the code to your specific dataset and requirements.\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33mragproxyagent\u001b[0m (to assistant):\n",
      "\n",
      "\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[33massistant\u001b[0m (to ragproxyagent):\n",
      "\n",
      "UPDATE CONTEXT\n",
      "\n",
      "--------------------------------------------------------------------------------\n",
      "\u001b[32mUpdating context and resetting conversation.\u001b[0m\n",
      "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n",
      "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n",
      "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n",
      "VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n",
      "\u001b[32mNo more context, will terminate.\u001b[0m\n",
      "\u001b[33mragproxyagent\u001b[0m (to assistant):\n",
      "\n",
      "TERMINATE\n",
      "\n",
      "--------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "# reset the assistant. Always reset the assistant before starting a new conversation.\n",
    "assistant.reset()\n",
    "\n",
    "# given a problem, we use the ragproxyagent to generate a prompt to be sent to the assistant as the initial message.\n",
    "# the assistant receives the message and generates a response. The response will be sent back to the ragproxyagent for processing.\n",
    "# The conversation continues until the termination condition is met; in RetrieveChat, with no human in the loop, the termination condition is that no code block is detected.\n",
    "# With human-in-loop, the conversation will continue until the user says \"exit\".\n",
    "code_problem = \"How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\"\n",
    "chat_result = ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=code_problem)"
   ]
  }
 ],
 "metadata": {
  "front_matter": {
   "description": "Explore the use of AutoGen's RetrieveChat for tasks like code generation from docstrings, answering complex questions with human feedback, and exploiting features like Update Context, custom prompts, and few-shot learning.",
   "tags": [
    "RAG"
   ]
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  },
  "skip_test": "Requires interactive usage"
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
1
setup.py
1
setup.py
|
@ -72,6 +72,7 @@ extra_require = {
    "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
    "retrievechat": retrieve_chat,
    "retrievechat-pgvector": retrieve_chat_pgvector,
    "retrievechat-mongodb": [*retrieve_chat, "pymongo>=4.0.0"],
    "retrievechat-qdrant": [*retrieve_chat, "qdrant_client", "fastembed>=0.3.1"],
    "autobuild": ["chromadb", "sentence-transformers", "huggingface-hub", "pysqlite3"],
    "teachable": ["chromadb"],
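For anyone consuming the new extra: install it with pip install "pyautogen[retrievechat-mongodb]" (package name inferred from the repository), after which a minimal smoke check that the optional dependencies resolved might be:

# Hedged smoke check: both imports should succeed once the extra is installed.
import pymongo
from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB

print(pymongo.version)  # the extra pins pymongo>=4.0.0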
test/agentchat/contrib/vectordb/test_mongodb.py
@ -0,0 +1,402 @@
import logging
import os
import random
from time import monotonic, sleep
from typing import List

import pytest

from autogen.agentchat.contrib.vectordb.base import Document

try:
    import pymongo
    import sentence_transformers  # imported to verify the default embedding backend is available

    from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB
except ImportError:
    # To display warning in pyproject.toml [tool.pytest.ini_options] set log_cli = true
    logger = logging.getLogger(__name__)
    logger.warning(f"skipping {__name__}. It requires pymongo to be installed, e.g. via the extra [retrievechat-mongodb]")
    pytest.skip(allow_module_level=True)

from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.errors import OperationFailure

logger = logging.getLogger(__name__)

MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/?directConnection=true")
MONGODB_DATABASE = os.environ.get("DATABASE", "autogen_test_db")
MONGODB_COLLECTION = os.environ.get("MONGODB_COLLECTION", "autogen_test_vectorstore")
MONGODB_INDEX = os.environ.get("MONGODB_INDEX", "vector_index")

RETRIES = 10
DELAY = 2
TIMEOUT = 120.0
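The defaults above line up with the mongodb-atlas-local container mapped to port 27017 in CI. To point the suite at a different deployment, a hypothetical override (placeholder URI, set before pytest imports this module) would be:

import os

os.environ["MONGODB_URI"] = "mongodb+srv://user:pass@cluster0.example.mongodb.net/"
os.environ["DATABASE"] = "autogen_test_db"
os.environ["MONGODB_COLLECTION"] = "autogen_test_vectorstore"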
def _wait_for_predicate(predicate, err, timeout=TIMEOUT, interval=DELAY):
    """Generic helper to block until the predicate returns true.

    Args:
        predicate (Callable[[], bool]): A function that returns a boolean value
        err (str): Error message to raise if the predicate never succeeds
        timeout (float, optional): Length of time to wait for predicate. Defaults to TIMEOUT.
        interval (float, optional): Interval to check predicate. Defaults to DELAY.

    Raises:
        TimeoutError: if the predicate does not return True within `timeout` seconds
    """
    start = monotonic()
    while not predicate():
        if monotonic() - start > timeout:
            raise TimeoutError(err)
        sleep(interval)
def _delete_search_indexes(collection: Collection, wait=True):
    """Deletes all search indexes in a collection

    Args:
        collection (pymongo.collection.Collection): MongoDB Collection Abstraction
        wait (bool, optional): Block until the indexes are confirmed deleted. Defaults to True.
    """
    for index in collection.list_search_indexes():
        try:
            collection.drop_search_index(index["name"])
        except OperationFailure:
            # Delete already issued
            pass
    if wait:
        _wait_for_predicate(lambda: not list(collection.list_search_indexes()), "Not all search indexes were deleted")
def _empty_collections_and_delete_indexes(database, collections=None, wait=True):
    """Empty all collections within the database and remove their search indexes

    Args:
        database (pymongo.database.Database): MongoDB Database Abstraction
        collections (list[str], optional): Names of collections to drop. Defaults to all collections in the database.
        wait (bool, optional): Block until search indexes are confirmed deleted. Defaults to True.
    """
    for collection_name in collections or database.list_collection_names():
        _delete_search_indexes(database[collection_name], wait)
        database[collection_name].drop()
@pytest.fixture
def db():
    """VectorDB setup and teardown, including collections and search indexes"""
    database = MongoClient(MONGODB_URI)[MONGODB_DATABASE]
    _empty_collections_and_delete_indexes(database)
    vectorstore = MongoDBAtlasVectorDB(
        connection_string=MONGODB_URI,
        database_name=MONGODB_DATABASE,
        wait_until_index_ready=TIMEOUT,
        overwrite=True,
    )
    yield vectorstore
    _empty_collections_and_delete_indexes(database)
@pytest.fixture
def example_documents() -> List[Document]:
    """Note the deliberate mix of integers and strings as ids"""
    return [
        Document(id=1, content="Dogs are tough.", metadata={"a": 1}),
        Document(id=2, content="Cats have fluff.", metadata={"b": 1}),
        Document(id="1", content="What is a sandwich?", metadata={"c": 1}),
        Document(id="2", content="A sandwich makes a great lunch.", metadata={"d": 1, "e": 2}),
    ]
@pytest.fixture
def db_with_indexed_clxn(collection_name):
    """VectorDB with a collection created immediately"""
    database = MongoClient(MONGODB_URI)[MONGODB_DATABASE]
    _empty_collections_and_delete_indexes(database, [collection_name], wait=True)
    vectorstore = MongoDBAtlasVectorDB(
        connection_string=MONGODB_URI,
        database_name=MONGODB_DATABASE,
        wait_until_index_ready=TIMEOUT,
        collection_name=collection_name,
        overwrite=True,
    )
    yield vectorstore, vectorstore.db[collection_name]
    _empty_collections_and_delete_indexes(database, [collection_name])
# Tracks collection-name suffixes already handed out so tests in this module don't collide.
_COLLECTION_NAMING_CACHE = []


@pytest.fixture
def collection_name():
    collection_id = random.randint(0, 100)
    while collection_id in _COLLECTION_NAMING_CACHE:
        collection_id = random.randint(0, 100)
    _COLLECTION_NAMING_CACHE.append(collection_id)

    return f"{MONGODB_COLLECTION}_{collection_id}"
def test_create_collection(db, collection_name):
    """
    def create_collection(collection_name: str,
                          overwrite: bool = False) -> Collection
    Create a collection in the vector database.
    - Case 1. if the collection does not exist, create the collection.
    - Case 2. the collection exists; if overwrite is True, overwrite the collection.
    - Case 3. the collection exists and overwrite is False: return the existing collection.
    - Case 4. the collection exists, overwrite is False, and get_or_create is False: raise a ValueError.
    """
    collection_case_1 = db.create_collection(
        collection_name=collection_name,
    )
    assert collection_case_1.name == collection_name

    collection_case_2 = db.create_collection(
        collection_name=collection_name,
        overwrite=True,
    )
    assert collection_case_2.name == collection_name

    collection_case_3 = db.create_collection(
        collection_name=collection_name,
    )
    assert collection_case_3.name == collection_name

    with pytest.raises(ValueError):
        db.create_collection(collection_name=collection_name, overwrite=False, get_or_create=False)
def test_get_collection(db, collection_name):
    with pytest.raises(ValueError):
        db.get_collection()

    collection_created = db.create_collection(collection_name)
    assert isinstance(collection_created, Collection)
    assert collection_created.name == collection_name

    collection_got = db.get_collection(collection_name)
    assert collection_got.name == collection_created.name
    assert collection_got.name == db.active_collection.name
def test_delete_collection(db, collection_name):
    assert collection_name not in db.list_collections()
    collection = db.create_collection(collection_name)
    assert collection_name in db.list_collections()
    db.delete_collection(collection.name)
    assert collection_name not in db.list_collections()
def test_insert_docs(db, collection_name, example_documents):
    # Test that there's an active collection
    with pytest.raises(ValueError) as exc:
        db.insert_docs(example_documents)
    assert "No collection is specified" in str(exc.value)

    # Test upsert
    db.insert_docs(example_documents, collection_name, upsert=True)

    # Create a collection
    db.delete_collection(collection_name)
    collection = db.create_collection(collection_name)

    # Insert example documents
    db.insert_docs(example_documents, collection_name=collection_name)
    found = list(collection.find({}))
    assert len(found) == len(example_documents)
    # Check that documents have correct fields, including "_id" and "embedding" but not "id"
    assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found])
    # Check ids
    assert {doc["_id"] for doc in found} == {1, "1", 2, "2"}
    # Check embedding lengths
    assert len(found[0]["embedding"]) == 384
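The 384 in the final assertion matches the output width of all-MiniLM-L6-v2, the usual sentence-transformers default for these vector stores; that this is the vectorstore's default model is an assumption here, but the figure itself is easy to verify:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
embedding = model.encode("Dogs are tough.")
print(embedding.shape)  # (384,) for this model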
def test_update_docs(db_with_indexed_clxn, example_documents):
    db, collection = db_with_indexed_clxn
    # Use update_docs to insert new documents
    db.update_docs(example_documents, collection.name, upsert=True)
    # Test that no changes were made to example_documents
    assert set(example_documents[0].keys()) == {"id", "content", "metadata"}
    assert collection.count_documents({}) == len(example_documents)
    found = list(collection.find({}))
    # Check that documents have correct fields, including "_id" and "embedding" but not "id"
    assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found])
    assert all([isinstance(doc["embedding"][0], float) for doc in found])
    assert all([len(doc["embedding"]) == db.dimensions for doc in found])
    # Check ids
    assert {doc["_id"] for doc in found} == {1, "1", 2, "2"}

    # Update an *existing* Document
    updated_doc = Document(id=1, content="Cats are tough.", metadata={"a": 10})
    db.update_docs([updated_doc], collection.name)
    assert collection.find_one({"_id": 1})["content"] == "Cats are tough."

    # Upsert a *new* Document
    new_id = 3
    new_doc = Document(id=new_id, content="Cats are tough.")
    db.update_docs([new_doc], collection.name, upsert=True)
    assert collection.find_one({"_id": new_id})["content"] == "Cats are tough."

    # Attempting to use update to insert a new doc
    # *without* upsert=True
    # is a no-op in MongoDB. # TODO Confirm behaviour and autogen's preference.
    new_id = 4
    new_doc = Document(id=new_id, content="That is NOT a sandwich?")
    db.update_docs([new_doc], collection.name)
    assert collection.find_one({"_id": new_id}) is None
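The no-op behaviour exercised in the last block mirrors plain pymongo semantics: replace_one without upsert=True matches nothing and inserts nothing. A standalone sketch (local connection string assumed):

from pymongo import MongoClient

clxn = MongoClient("mongodb://localhost:27017")["demo_db"]["demo_clxn"]
result = clxn.replace_one({"_id": 4}, {"content": "That is NOT a sandwich?"})
# Nothing matched and nothing was inserted.
assert result.matched_count == 0 and result.upserted_id is None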
def test_delete_docs(db_with_indexed_clxn, example_documents):
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)
    # Delete the 1s
    db.delete_docs(ids=[1, "1"], collection_name=clxn.name)
    # Confirm just the 2s remain
    assert {2, "2"} == {doc["_id"] for doc in clxn.find({})}
def test_get_docs_by_ids(db_with_indexed_clxn, example_documents):
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)

    # Test without setting "include" kwarg
    docs = db.get_docs_by_ids(ids=[2, "2"], collection_name=clxn.name)
    assert len(docs) == 2
    assert all([doc["id"] in [2, "2"] for doc in docs])
    assert set(docs[0].keys()) == {"id", "content", "metadata"}

    # Test with include
    docs = db.get_docs_by_ids(ids=[2], include=["content"], collection_name=clxn.name)
    assert len(docs) == 1
    assert set(docs[0].keys()) == {"id", "content"}

    # Test with empty ids list
    docs = db.get_docs_by_ids(ids=[], include=["content"], collection_name=clxn.name)
    assert len(docs) == 0

    # Test with ids=None: all documents are returned
    docs = db.get_docs_by_ids(ids=None, include=["content"], collection_name=clxn.name)
    assert len(docs) == 4
def test_retrieve_docs_empty(db_with_indexed_clxn):
    db, clxn = db_with_indexed_clxn
    assert db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=2) == []
def test_retrieve_docs_populated_db_empty_query(db_with_indexed_clxn, example_documents):
    db, clxn = db_with_indexed_clxn
    db.insert_docs(example_documents, collection_name=clxn.name)
    # An empty list of queries returns an empty list of results
    results = db.retrieve_docs(queries=[], collection_name=clxn.name, n_results=2)
    assert results == []
def test_retrieve_docs(db_with_indexed_clxn, example_documents):
    """Begin testing Atlas Vector Search

    NOTE: Indexing may take some time, so we must be patient on the first query.
    We use the wait_until_index_ready flag to ensure the index is created and ready.
    Adding documents and querying them immediately afterwards is only standard in testing.
    """
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)

    n_results = 2  # Number of closest docs to return

    def results_ready():
        results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results)
        return len(results[0]) == n_results

    _wait_for_predicate(results_ready, f"Failed to retrieve docs within {TIMEOUT} seconds.")

    results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results)
    assert {doc[0]["id"] for doc in results[0]} == {1, 2}
    assert all(["embedding" not in doc[0] for doc in results[0]])
def test_retrieve_docs_with_embedding(db_with_indexed_clxn, example_documents):
    """Begin testing Atlas Vector Search

    NOTE: Indexing may take some time, so we must be patient on the first query.
    We use the wait_until_index_ready flag to ensure the index is created and ready.
    Adding documents and querying them immediately afterwards is only standard in testing.
    """
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)

    n_results = 2  # Number of closest docs to return

    def results_ready():
        results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results)
        return len(results[0]) == n_results

    _wait_for_predicate(results_ready, f"Failed to retrieve docs within {TIMEOUT} seconds.")

    results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results, include_embedding=True)
    assert {doc[0]["id"] for doc in results[0]} == {1, 2}
    assert all(["embedding" in doc[0] for doc in results[0]])
def test_retrieve_docs_multiple_queries(db_with_indexed_clxn, example_documents):
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)
    n_results = 2  # Number of closest docs to return

    queries = ["Some good pets", "What kind of Sandwich?"]

    def results_ready():
        results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results)
        return all([len(res) == n_results for res in results])

    _wait_for_predicate(results_ready, f"Failed to retrieve docs within {TIMEOUT} seconds.")

    results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=2)

    assert len(results) == len(queries)
    assert all([len(res) == n_results for res in results])
    assert {doc[0]["id"] for doc in results[0]} == {1, 2}
    assert {doc[0]["id"] for doc in results[1]} == {"1", "2"}
def test_retrieve_docs_with_threshold(db_with_indexed_clxn, example_documents):
    db, clxn = db_with_indexed_clxn
    # Insert example documents
    db.insert_docs(example_documents, collection_name=clxn.name)

    n_results = 2  # Number of closest docs to return
    queries = ["Cats"]

    def results_ready():
        results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results)
        return len(results[0]) == n_results

    _wait_for_predicate(results_ready, f"Failed to retrieve docs within {TIMEOUT} seconds.")

    # A distance threshold of 0.3 means the similarity score must be 0.7 or greater;
    # only one of the example documents should clear that bar.
    results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results, distance_threshold=0.3)
    assert len(results[0]) == 1
    assert all([doc[1] >= 0.7 for doc in results[0]])
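The arithmetic in the threshold test assumes the implementation converts distance_threshold t into a minimum similarity score of 1 - t; the exact conversion lives in mongodb.py, so treat this as a sketch of the invariant rather than its code:

def passes_threshold(score: float, distance_threshold: float = -1.0) -> bool:
    # A negative threshold disables filtering; otherwise keep docs scoring >= 1 - threshold.
    return distance_threshold < 0 or score >= 1 - distance_threshold

assert passes_threshold(0.72, 0.3)      # 0.72 >= 0.7 -> kept
assert not passes_threshold(0.55, 0.3)  # 0.55 < 0.7 -> filtered out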
def test_wait_until_document_ready(collection_name, example_documents):
    database = MongoClient(MONGODB_URI)[MONGODB_DATABASE]
    _empty_collections_and_delete_indexes(database, [collection_name], wait=True)
    try:
        vectorstore = MongoDBAtlasVectorDB(
            connection_string=MONGODB_URI,
            database_name=MONGODB_DATABASE,
            wait_until_index_ready=TIMEOUT,
            collection_name=collection_name,
            overwrite=True,
            wait_until_document_ready=TIMEOUT,
        )
        vectorstore.insert_docs(example_documents)
        assert vectorstore.retrieve_docs(queries=["Cats"], n_results=4)
    finally:
        _empty_collections_and_delete_indexes(database, [collection_name])