+mdb atlas vectordb [clean_final] (#3000)

* +mdb atlas

* Update test/agentchat/contrib/vectordb/test_mongodb.py

Co-authored-by: HRUSHIKESH DOKALA <96101829+Hk669@users.noreply.github.com>

* update test_mongodb.py; we don't need to do the assert .collection_name vs .name

* Try fix mongodb service

* Try fix mongodb service

* Update username and password

* Update autogen/agentchat/contrib/vectordb/mongodb.py

* closer --- but I'm not super thrilled about the solution...

* PYTHON-4506 Expanded tests and simplified vector search pipelines

* Update mongodb.py

* Update mongodb.py - Casey

* search_index_magic

index_name change; keeping track of Lucene indexes is tricky

* Fix format

* Fix tests

* hacking trying to figure this out

* Streamline checks for indexes in construction and restructure tests

* Add tests for score_threshold, embedding inclusion, and multiple query tests

* refactored create_collection to meet base object requirements

* lint

* change the localhost port to 27017

* add test to check that no embedding is there unless explicitly provided

* Update logger

* Add test get docs with ids=None

* Rename and update notebook

* have index management include waiting behaviors

* Adds further optional waits for users and tests. Cleans up upsert.

* ensure the embedding size for multiple embedding inputs is equal to dimensions

* fix up tests and add configuration to ensure documents and indexes are READY for querying

* fix import failure

* adjust typing for 3.9

* fix up the notebook output

* changed language to communicate time taken on first init_chat call

* replace environment variable usage

---------

Co-authored-by: Fabian Valle <fabian.valle-simmons@mongodb.com>
Co-authored-by: HRUSHIKESH DOKALA <96101829+Hk669@users.noreply.github.com>
Co-authored-by: Li Jiang <bnujli@gmail.com>
Co-authored-by: Casey Clements <casey.clements@mongodb.com>
Co-authored-by: Jib <jib.adegunloye@mongodb.com>
Co-authored-by: Jib <Jibzade@gmail.com>
Co-authored-by: Cozypet <yanhan860711@gmail.com>
Fabian Valle 2024-07-25 19:11:19 -04:00 committed by GitHub
parent 1bd2124ba4
commit f9295c4c39
6 changed files with 1561 additions and 2 deletions

View File

@@ -87,6 +87,10 @@ jobs:
--health-retries 5
ports:
- 5432:5432
mongodb:
image: mongodb/mongodb-atlas-local:latest
ports:
- 27017:27017
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
@@ -104,6 +108,9 @@ jobs:
- name: Install pgvector when on linux
run: |
pip install -e .[retrievechat-pgvector]
- name: Install mongodb when on linux
run: |
pip install -e .[retrievechat-mongodb]
- name: Install unstructured when python-version is 3.9 and on linux
if: matrix.python-version == '3.9'
run: |

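The workflow hunks above start a local Atlas-compatible container (mongodb/mongodb-atlas-local) on port 27017 and install the retrievechat-mongodb extra. Outside CI, a minimal sketch like the following (assuming the same localhost URI the test suite defaults to) confirms the container is reachable before running the tests:

```python
# Minimal sketch (not part of the workflow): confirm the local Atlas container is up.
# The URI matches the default used by test_mongodb.py.
from pymongo import MongoClient
from pymongo.errors import ServerSelectionTimeoutError

uri = "mongodb://localhost:27017/?directConnection=true"
client = MongoClient(uri, serverSelectionTimeoutMS=5000)
try:
    client.admin.command("ping")  # same health check MongoDBAtlasVectorDB runs on init
    print("mongodb-atlas-local is reachable")
except ServerSelectionTimeoutError as err:
    raise SystemExit(f"MongoDB service is not reachable: {err}")
```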
View File

@@ -186,7 +186,8 @@ class VectorDB(Protocol):
ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: str | The name of the collection. Default is None.
include: List[str] | The fields to include. Default is None.
If None, will include ["metadatas", "documents"], ids will always be included.
If None, will include ["metadatas", "documents"], ids will always be included. This may differ
depending on the implementation.
kwargs: dict | Additional keyword arguments.
Returns:
@@ -200,7 +201,7 @@ class VectorDBFactory:
Factory class for creating vector databases.
"""
PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "qdrant"]
PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "mongodb", "qdrant"]
@staticmethod
def create_vector_db(db_type: str, **kwargs) -> VectorDB:
@@ -222,6 +223,10 @@ class VectorDBFactory:
from .pgvectordb import PGVectorDB
return PGVectorDB(**kwargs)
if db_type.lower() in ["mdb", "mongodb", "atlas"]:
from .mongodb import MongoDBAtlasVectorDB
return MongoDBAtlasVectorDB(**kwargs)
if db_type.lower() in ["qdrant", "qdrantdb"]:
from .qdrant import QdrantVectorDB

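With the factory change above, the new backend can be created through VectorDBFactory using any of the "mdb", "mongodb", or "atlas" aliases; the keyword arguments are forwarded to the MongoDBAtlasVectorDB constructor. A minimal sketch, assuming MONGODB_URI holds a valid Atlas connection string:

```python
# Minimal sketch: create the MongoDB backend via the factory.
# Assumes MONGODB_URI is set to an Atlas (or atlas-local) connection string.
import os
from autogen.agentchat.contrib.vectordb.base import VectorDBFactory

db = VectorDBFactory.create_vector_db(
    db_type="mongodb",  # "mdb" and "atlas" are accepted aliases
    connection_string=os.environ["MONGODB_URI"],
    database_name="vector_db",
    index_name="vector_index",
)
print(type(db).__name__)  # MongoDBAtlasVectorDB
```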
View File

@@ -0,0 +1,553 @@
from copy import deepcopy
from time import monotonic, sleep
from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Set, Tuple, Union
import numpy as np
from pymongo import MongoClient, UpdateOne, errors
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.operations import SearchIndexModel
from sentence_transformers import SentenceTransformer
from .base import Document, ItemID, QueryResults, VectorDB
from .utils import get_logger
logger = get_logger(__name__)
DEFAULT_INSERT_BATCH_SIZE = 100_000
_SAMPLE_SENTENCE = ["The weather is lovely today in paradise."]
_DELAY = 0.5
def with_id_rename(docs: Iterable) -> List[Dict[str, Any]]:
"""Utility changes _id field from Collection into id for Document."""
return [{**{k: v for k, v in d.items() if k != "_id"}, "id": d["_id"]} for d in docs]
class MongoDBAtlasVectorDB(VectorDB):
"""
A vector database implementation that stores documents in a MongoDB Atlas Collection and queries them with Atlas Vector Search.
"""
def __init__(
self,
connection_string: str = "",
database_name: str = "vector_db",
embedding_function: Callable = SentenceTransformer("all-MiniLM-L6-v2").encode,
collection_name: str = None,
index_name: str = "vector_index",
overwrite: bool = False,
wait_until_index_ready: float = None,
wait_until_document_ready: float = None,
):
"""
Initialize the vector database.
Args:
connection_string: str | The MongoDB connection string to connect to. Default is ''.
database_name: str | The name of the database. Default is 'vector_db'.
embedding_function: Callable | The embedding function used to generate the vector representation.
collection_name: str | The name of the collection to create for this vector database.
Defaults to None.
index_name: str | Index name for the vector database, defaults to 'vector_index'.
overwrite: bool | Whether to overwrite an existing collection with the same name. Default is False.
wait_until_index_ready: float | None | Blocking call to wait until the
database indexes are ready. None, the default, means no wait.
wait_until_document_ready: float | None | Blocking call to wait until
inserted documents are queryable. None, the default, means no wait.
"""
self.embedding_function = embedding_function
self.index_name = index_name
self._wait_until_index_ready = wait_until_index_ready
self._wait_until_document_ready = wait_until_document_ready
# This will get the model dimension size by computing the embeddings dimensions
self.dimensions = self._get_embedding_size()
try:
self.client = MongoClient(connection_string, driver=DriverInfo(name="autogen"))
self.client.admin.command("ping")
logger.debug("Successfully created MongoClient")
except errors.ServerSelectionTimeoutError as err:
raise ConnectionError("Could not connect to MongoDB server") from err
self.db = self.client[database_name]
logger.debug(f"Atlas Database name: {self.db.name}")
if collection_name:
self.active_collection = self.create_collection(collection_name, overwrite)
else:
self.active_collection = None
def _is_index_ready(self, collection: Collection, index_name: str):
"""Check for the index name in the list of available search indexes to see if the
specified index is of status READY
Args:
collection (Collection): MongoDB Collection to check for the search index
index_name (str): Vector Search Index name
Returns:
bool : True if the index is present and READY, False otherwise
"""
for index in collection.list_search_indexes(index_name):
if index["type"] == "vectorSearch" and index["status"] == "READY":
return True
return False
def _wait_for_index(self, collection: Collection, index_name: str, action: str = "create"):
"""Waits for the index action to be completed. Otherwise throws a TimeoutError.
Timeout set on instantiation.
action: "create" or "delete"
"""
assert action in ["create", "delete"], f"{action=} must be create or delete."
start = monotonic()
while monotonic() - start < self._wait_until_index_ready:
if action == "create" and self._is_index_ready(collection, index_name):
return
elif action == "delete" and len(list(collection.list_search_indexes())) == 0:
return
sleep(_DELAY)
raise TimeoutError(f"Index {self.index_name} is not ready!")
def _wait_for_document(self, collection: Collection, index_name: str, doc: Document):
start = monotonic()
while monotonic() - start < self._wait_until_document_ready:
query_result = _vector_search(
embedding_vector=np.array(self.embedding_function(doc["content"])).tolist(),
n_results=1,
collection=collection,
index_name=index_name,
)
if query_result and query_result[0][0]["_id"] == doc["id"]:
return
sleep(_DELAY)
raise TimeoutError(f"Document {doc['id']} is not ready!")
def _get_embedding_size(self):
return len(self.embedding_function(_SAMPLE_SENTENCE)[0])
def list_collections(self):
"""
List the collections in the vector database.
Returns:
List[str] | The list of collections.
"""
return self.db.list_collection_names()
def create_collection(
self,
collection_name: str,
overwrite: bool = False,
get_or_create: bool = True,
) -> Collection:
"""
Create a collection in the vector database and create a vector search index in the collection.
Args:
collection_name: str | The name of the collection.
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
get_or_create: bool | Whether to get or create the collection. Default is True
"""
if overwrite:
self.delete_collection(collection_name)
if collection_name not in self.db.list_collection_names():
# Create a new collection
coll = self.db.create_collection(collection_name)
self.create_index_if_not_exists(index_name=self.index_name, collection=coll)
return coll
if get_or_create:
# The collection already exists, return it.
coll = self.db[collection_name]
self.create_index_if_not_exists(index_name=self.index_name, collection=coll)
return coll
else:
# get_or_create is False and the collection already exists, raise an error.
raise ValueError(f"Collection {collection_name} already exists.")
def create_index_if_not_exists(self, index_name: str = "vector_index", collection: Collection = None) -> None:
"""
Creates a vector search index on the specified collection in MongoDB.
Args:
index_name (str, optional): The name of the vector search index to create. Defaults to "vector_index".
collection (Collection, optional): The MongoDB collection to create the index on. Defaults to None.
"""
if not self._is_index_ready(collection, index_name):
self.create_vector_search_index(collection, index_name)
def get_collection(self, collection_name: str = None) -> Collection:
"""
Get the collection from the vector database.
Args:
collection_name: str | The name of the collection. Default is None. If None, return the
current active collection.
Returns:
Collection | The collection object.
"""
if collection_name is None:
if self.active_collection is None:
raise ValueError("No collection is specified.")
else:
logger.debug(
f"No collection is specified. Using current active collection {self.active_collection.name}."
)
else:
self.active_collection = self.db[collection_name]
return self.active_collection
def delete_collection(self, collection_name: str) -> None:
"""
Delete the collection from the vector database.
Args:
collection_name: str | The name of the collection.
"""
for index in self.db[collection_name].list_search_indexes():
self.db[collection_name].drop_search_index(index["name"])
if self._wait_until_index_ready:
self._wait_for_index(self.db[collection_name], index["name"], "delete")
return self.db[collection_name].drop()
def create_vector_search_index(
self,
collection: Collection,
index_name: Union[str, None] = "vector_index",
similarity: Literal["euclidean", "cosine", "dotProduct"] = "cosine",
) -> None:
"""Create a vector search index in the collection.
Args:
collection: An existing Collection in the Atlas Database.
index_name: Vector Search Index name.
similarity: Algorithm used for measuring vector similarity.
kwargs: Additional keyword arguments.
Returns:
None
"""
search_index_model = SearchIndexModel(
definition={
"fields": [
{
"type": "vector",
"numDimensions": self.dimensions,
"path": "embedding",
"similarity": similarity,
},
]
},
name=index_name,
type="vectorSearch",
)
# Create the search index
try:
collection.create_search_index(model=search_index_model)
if self._wait_until_index_ready:
self._wait_for_index(collection, index_name, "create")
logger.debug(f"Search index {index_name} created successfully.")
except Exception as e:
logger.error(
f"Error creating search index: {e}. \n"
f"Your client must be connected to an Atlas cluster. "
f"You may have to manually create a Collection and Search Index "
f"if you are on a free/shared cluster."
)
raise e
def insert_docs(
self,
docs: List[Document],
collection_name: str = None,
upsert: bool = False,
batch_size=DEFAULT_INSERT_BATCH_SIZE,
**kwargs,
) -> None:
"""Insert Documents and Vector Embeddings into the collection of the vector database.
For large numbers of Documents, insertion is performed in batches.
Args:
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
collection_name: str | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
batch_size: Number of documents to be inserted in each batch
"""
if not docs:
logger.info("No documents to insert.")
return
collection = self.get_collection(collection_name)
if upsert:
self.update_docs(docs, collection.name, upsert=True)
else:
# Sanity checking the first document
if docs[0].get("content") is None:
raise ValueError("The document content is required.")
if docs[0].get("id") is None:
raise ValueError("The document id is required.")
input_ids = set()
result_ids = set()
id_batch = []
text_batch = []
metadata_batch = []
size = 0
i = 0
for doc in docs:
id = doc["id"]
text = doc["content"]
metadata = doc.get("metadata", {})
id_batch.append(id)
text_batch.append(text)
metadata_batch.append(metadata)
id_size = 1 if isinstance(id, int) else len(id)
size += len(text) + len(metadata) + id_size
if (i + 1) % batch_size == 0 or size >= 47_000_000:
result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch))
input_ids.update(id_batch)
id_batch = []
text_batch = []
metadata_batch = []
size = 0
i += 1
if text_batch:
result_ids.update(self._insert_batch(collection, text_batch, metadata_batch, id_batch)) # type: ignore
input_ids.update(id_batch)
if result_ids != input_ids:
logger.warning(
"Possible data corruption. "
"input_ids not in result_ids: {in_diff}.\n"
"result_ids not in input_ids: {out_diff}".format(
in_diff=input_ids.difference(result_ids), out_diff=result_ids.difference(input_ids)
)
)
if self._wait_until_document_ready and docs:
self._wait_for_document(collection, self.index_name, docs[-1])
def _insert_batch(
self, collection: Collection, texts: List[str], metadatas: List[Mapping[str, Any]], ids: List[ItemID]
) -> Set[ItemID]:
"""Compute embeddings for and insert a batch of Documents into the Collection.
For performance reasons, we chose to call self.embedding_function just once,
with the hopefully small tradeoff of having to recreate the Document dicts.
Args:
collection: MongoDB Collection
texts: List of the main contents of each document
metadatas: List of metadata mappings
ids: List of ids. Note that these are stored as _id in Collection.
Returns:
List of ids inserted.
"""
n_texts = len(texts)
if n_texts == 0:
return []
# Embed and create the documents
embeddings = self.embedding_function(texts).tolist()
assert (
len(embeddings) == n_texts
), f"The number of embeddings produced by self.embedding_function ({len(embeddings)}) does not match the number of texts provided to it ({n_texts})."
to_insert = [
{"_id": i, "content": t, "metadata": m, "embedding": e}
for i, t, m, e in zip(ids, texts, metadatas, embeddings)
]
# insert the documents in MongoDB Atlas
insert_result = collection.insert_many(to_insert) # type: ignore
return insert_result.inserted_ids # TODO Remove this. Replace by log like update_docs
def update_docs(self, docs: List[Document], collection_name: str = None, **kwargs: Any) -> None:
"""Update documents, including their embeddings, in the Collection.
Optionally allow upsert as kwarg.
Uses deepcopy to avoid changing docs.
Args:
docs: List[Document] | A list of documents.
collection_name: str | The name of the collection. Default is None.
kwargs: Any | Use `upsert=True` to insert documents whose ids are not present in the collection.
"""
n_docs = len(docs)
logger.info(f"Preparing to embed and update {n_docs=}")
# Compute the embeddings
embeddings: list[list[float]] = self.embedding_function([doc["content"] for doc in docs]).tolist()
# Prepare the updates
all_updates = []
for i in range(n_docs):
doc = deepcopy(docs[i])
doc["embedding"] = embeddings[i]
doc["_id"] = doc.pop("id")
all_updates.append(UpdateOne({"_id": doc["_id"]}, {"$set": doc}, upsert=kwargs.get("upsert", False)))
# Perform update in bulk
collection = self.get_collection(collection_name)
result = collection.bulk_write(all_updates)
if self._wait_until_document_ready and docs:
self._wait_for_document(collection, self.index_name, docs[-1])
# Log a result summary
logger.info(
"Matched: %s, Modified: %s, Upserted: %s",
result.matched_count,
result.modified_count,
result.upserted_count,
)
def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs):
"""
Delete documents from the collection of the vector database.
Args:
ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
collection_name: str | The name of the collection. Default is None.
"""
collection = self.get_collection(collection_name)
return collection.delete_many({"_id": {"$in": ids}})
def get_docs_by_ids(
self, ids: List[ItemID] = None, collection_name: str = None, include: List[str] = None, **kwargs
) -> List[Document]:
"""
Retrieve documents from the collection of the vector database based on the ids.
Args:
ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: str | The name of the collection. Default is None.
include: List[str] | The fields to include.
If None, will include ["metadata", "content"], ids will always be included.
Use `include` to choose whether to return the embedding and metadata fields.
kwargs: dict | Additional keyword arguments.
Returns:
List[Document] | The results.
"""
if include is None:
include_fields = {"_id": 1, "content": 1, "metadata": 1}
else:
include_fields = {k: 1 for k in set(include).union({"_id"})}
collection = self.get_collection(collection_name)
if ids is not None:
docs = collection.find({"_id": {"$in": ids}}, include_fields)
# Return with _id field from Collection into id for Document
return with_id_rename(docs)
else:
docs = collection.find({}, include_fields)
# Return with _id field from Collection into id for Document
return with_id_rename(docs)
def retrieve_docs(
self,
queries: List[str],
collection_name: str = None,
n_results: int = 10,
distance_threshold: float = -1,
**kwargs,
) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.
Args:
queries: List[str] | A list of queries. Each query is a string.
collection_name: str | The name of the collection. Default is None.
n_results: int | The number of relevant documents to return. Default is 10.
distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
returned. Don't filter with it if < 0. Default is -1.
kwargs: Dict | Additional keyword arguments. Ones of importance follow:
oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm.
It determines the number of nearest neighbor candidates to consider during the search phase.
A higher value leads to more accuracy, but is slower. Default is 10
Returns:
QueryResults | For each query string, a list of nearest documents and their scores.
"""
collection = self.get_collection(collection_name)
# Trivial case of an empty collection
if collection.count_documents({}) == 0:
return []
logger.debug(f"Using index: {self.index_name}")
results = []
for query_text in queries:
# Compute embedding vector from semantic query
logger.debug(f"Query: {query_text}")
query_vector = np.array(self.embedding_function([query_text])).tolist()[0]
# Find documents with similar vectors using the specified index
query_result = _vector_search(
query_vector,
n_results,
collection,
self.index_name,
distance_threshold,
**kwargs,
oversampling_factor=kwargs.get("oversampling_factor", 10),
)
# Change each _id key to id. with_id_rename, but with (doc, score) tuples
results.append(
[({**{k: v for k, v in d[0].items() if k != "_id"}, "id": d[0]["_id"]}, d[1]) for d in query_result]
)
return results
def _vector_search(
embedding_vector: List[float],
n_results: int,
collection: Collection,
index_name: str,
distance_threshold: float = -1.0,
oversampling_factor=10,
include_embedding=False,
) -> List[Tuple[Dict, float]]:
"""Core $vectorSearch Aggregation pipeline.
Args:
embedding_vector: Embedding vector of semantic query
n_results: Number of documents to return.
collection: MongoDB Collection with vector index
index_name: Name of the vector index
distance_threshold: Only distance measures smaller than this will be returned.
Don't filter with it if < 0. Default is -1.0.
oversampling_factor: int | This times n_results is 'ef' in the HNSW algorithm.
It determines the number of nearest neighbor candidates to consider during the search phase.
A higher value leads to more accuracy, but is slower. Default = 10
Returns:
List of tuples of length n_results from Collection.
Each tuple contains a document dict and a score.
"""
pipeline = [
{
"$vectorSearch": {
"index": index_name,
"limit": n_results,
"numCandidates": n_results * oversampling_factor,
"queryVector": embedding_vector,
"path": "embedding",
}
},
{"$set": {"score": {"$meta": "vectorSearchScore"}}},
]
if distance_threshold >= 0.0:
similarity_threshold = 1.0 - distance_threshold
pipeline.append({"$match": {"score": {"$gte": similarity_threshold}}})
if not include_embedding:
pipeline.append({"$project": {"embedding": 0}})
logger.debug("pipeline: %s", pipeline)
agg = collection.aggregate(pipeline)
return [(doc, doc.pop("score")) for doc in agg]

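Putting the module above together, a minimal end-to-end sketch follows. It assumes MONGODB_URI points at an Atlas cluster (or the mongodb-atlas-local container) that supports search indexes; the wait_until_* timeouts block until the index is READY and the inserted documents are queryable, and retrieve_docs runs the $vectorSearch pipeline built in _vector_search. Collection and document values are illustrative.

```python
# Minimal usage sketch of MongoDBAtlasVectorDB; requires the
# retrievechat-mongodb extra (pymongo>=4.0.0) and sentence-transformers.
import os
from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB

db = MongoDBAtlasVectorDB(
    connection_string=os.environ["MONGODB_URI"],
    database_name="vector_db",
    collection_name="demo_collection",
    index_name="vector_index",
    overwrite=True,                   # drop and recreate the collection + index
    wait_until_index_ready=120.0,     # block until the search index is READY
    wait_until_document_ready=120.0,  # block until inserted docs are queryable
)

docs = [
    {"id": 1, "content": "Dogs are tough.", "metadata": {"a": 1}},
    {"id": "2", "content": "A sandwich makes a great lunch.", "metadata": {"b": 2}},
]
db.insert_docs(docs, collection_name="demo_collection")

# Fetch by id; `include` selects the returned fields (ids are always included).
print(db.get_docs_by_ids(ids=[1], include=["content"]))

# Each query returns a list of (document, score) tuples.
results = db.retrieve_docs(queries=["Good pets"], n_results=2, distance_threshold=0.3)
for doc, score in results[0]:
    print(doc["id"], round(score, 3))
```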
View File

@@ -0,0 +1,591 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using RetrieveChat Powered by MongoDB Atlas for Retrieve Augmented Code Generation and Question Answering\n",
"\n",
"AutoGen offers conversable agents powered by LLM, tool or human, which can be used to perform tasks collectively via automated chat. This framework allows tool use and human participation through multi-agent conversation.\n",
"Please find documentation about this feature [here](https://microsoft.github.io/autogen/docs/Use-Cases/agent_chat).\n",
"\n",
"RetrieveChat is a conversational system for retrieval-augmented code generation and question answering. In this notebook, we demonstrate how to utilize RetrieveChat to generate code and answer questions based on customized documentations that are not present in the LLM's training dataset. RetrieveChat uses the `RetrieveAssistantAgent` and `RetrieveUserProxyAgent`, which is similar to the usage of `AssistantAgent` and `UserProxyAgent` in other notebooks (e.g., [Automated Task Solving with Code Generation, Execution & Debugging](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_auto_feedback_from_code_execution.ipynb)). Essentially, `RetrieveAssistantAgent` and `RetrieveUserProxyAgent` implement a different auto-reply mechanism corresponding to the RetrieveChat prompts.\n",
"\n",
"## Table of Contents\n",
"We'll demonstrate six examples of using RetrieveChat for code generation and question answering:\n",
"\n",
"- [Example 1: Generate code based off docstrings w/o human feedback](#example-1)\n",
"\n",
"````{=mdx}\n",
":::info Requirements\n",
"Some extra dependencies are needed for this notebook, which can be installed via pip:\n",
"\n",
"```bash\n",
"pip install pyautogen[retrievechat-mongodb] flaml[automl]\n",
"```\n",
"\n",
"For more information, please refer to the [installation guide](/docs/installation/).\n",
":::\n",
"````\n",
"\n",
"Ensure you have a MongoDB Atlas instance with Cluster Tier >= M30."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set your API Endpoint\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"models to use: ['gpt-3.5-turbo-0125']\n"
]
}
],
"source": [
"import json\n",
"import os\n",
"\n",
"import autogen\n",
"from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent\n",
"from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent\n",
"\n",
"# Accepted file formats for that can be stored in\n",
"# a vector database instance\n",
"from autogen.retrieve_utils import TEXT_FORMATS\n",
"\n",
"config_list = [{\"model\": \"gpt-3.5-turbo-0125\", \"api_key\": os.environ[\"OPENAI_API_KEY\"], \"api_type\": \"openai\"}]\n",
"assert len(config_list) > 0\n",
"print(\"models to use: \", [config_list[i][\"model\"] for i in range(len(config_list))])"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"````{=mdx}\n",
":::tip\n",
"Learn more about configuring LLMs for agents [here](/docs/topics/llm_configuration).\n",
":::\n",
"````\n",
"\n",
"## Construct agents for RetrieveChat\n",
"\n",
"We start by initializing the `RetrieveAssistantAgent` and `RetrieveUserProxyAgent`. The system message needs to be set to \"You are a helpful assistant.\" for RetrieveAssistantAgent. The detailed instructions are given in the user message. Later we will use the `RetrieveUserProxyAgent.message_generator` to combine the instructions and a retrieval augmented generation task for an initial prompt to be sent to the LLM assistant."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accepted file formats for `docs_path`:\n",
"['txt', 'json', 'csv', 'tsv', 'md', 'html', 'htm', 'rtf', 'rst', 'jsonl', 'log', 'xml', 'yaml', 'yml', 'pdf']\n"
]
}
],
"source": [
"print(\"Accepted file formats for `docs_path`:\")\n",
"print(TEXT_FORMATS)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# 1. create an RetrieveAssistantAgent instance named \"assistant\"\n",
"assistant = RetrieveAssistantAgent(\n",
" name=\"assistant\",\n",
" system_message=\"You are a helpful assistant.\",\n",
" llm_config={\n",
" \"timeout\": 600,\n",
" \"cache_seed\": 42,\n",
" \"config_list\": config_list,\n",
" },\n",
")\n",
"\n",
"# 2. create the RetrieveUserProxyAgent instance named \"ragproxyagent\"\n",
"# By default, the human_input_mode is \"ALWAYS\", which means the agent will ask for human input at every step. We set it to \"NEVER\" here.\n",
"# `docs_path` is the path to the docs directory. It can also be the path to a single file, or the url to a single file. By default,\n",
"# it is set to None, which works only if the collection is already created.\n",
"# `task` indicates the kind of task we're working on. In this example, it's a `code` task.\n",
"# `chunk_token_size` is the chunk token size for the retrieve chat. By default, it is set to `max_tokens * 0.6`, here we set it to 2000.\n",
"# `custom_text_types` is a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.\n",
"# This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types.\n",
"# In this example, we set it to [\"non-existent-type\"] to only process markdown files. Since no \"non-existent-type\" files are included in the `websit/docs`,\n",
"# no files there will be processed. However, the explicitly included urls will still be processed.\n",
"# **NOTE** Upon the first time adding in the documents, initial query may be slower due to index creation and document indexing time\n",
"ragproxyagent = RetrieveUserProxyAgent(\n",
" name=\"ragproxyagent\",\n",
" human_input_mode=\"NEVER\",\n",
" max_consecutive_auto_reply=3,\n",
" retrieve_config={\n",
" \"task\": \"code\",\n",
" \"docs_path\": [\n",
" \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Examples/Integrate%20-%20Spark.md\",\n",
" \"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Research.md\",\n",
" os.path.join(os.path.abspath(\"\"), \"..\", \"website\", \"docs\"),\n",
" ],\n",
" \"custom_text_types\": [\"non-existent-type\"],\n",
" \"chunk_token_size\": 2000,\n",
" \"model\": config_list[0][\"model\"],\n",
" \"vector_db\": \"mongodb\", # MongoDB Atlas database\n",
" \"collection_name\": \"demo_collection\",\n",
" \"db_config\": {\n",
" \"connection_string\": os.environ[\"MONGODB_URI\"], # MongoDB Atlas connection string\n",
" \"database_name\": \"test_db\", # MongoDB Atlas database\n",
" \"index_name\": \"vector_index\",\n",
" \"wait_until_index_ready\": 120.0, # Setting to wait 120 seconds or until index is constructed before querying\n",
" \"wait_until_document_ready\": 120.0, # Setting to wait 120 seconds or until document is properly indexed after insertion/update\n",
" },\n",
" \"get_or_create\": True, # set to False if you don't want to reuse an existing collection\n",
" \"overwrite\": False, # set to True if you want to overwrite an existing collection, each overwrite will force a index creation and reupload of documents\n",
" },\n",
" code_execution_config=False, # set to False if you don't want to execute the code\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example 1\n",
"\n",
"[Back to top](#table-of-contents)\n",
"\n",
"Use RetrieveChat to help generate sample code and automatically run the code and fix errors if there is any.\n",
"\n",
"Problem: Which API should I use if I want to use FLAML for a classification task and I want to train the model in 30 seconds. Use spark to parallel the training. Force cancel jobs if time limit is reached."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-07-25 13:47:30,700 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - \u001b[32mUse the existing collection `demo_collection`.\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trying to create collection.\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-07-25 13:47:31,048 - autogen.agentchat.contrib.retrieve_user_proxy_agent - INFO - Found 2 chunks.\u001b[0m\n",
"2024-07-25 13:47:31,051 - autogen.agentchat.contrib.vectordb.mongodb - INFO - No documents to insert.\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n",
"\u001b[32mAdding content of doc bdfbc921 to context.\u001b[0m\n",
"\u001b[32mAdding content of doc 7968cf3c to context.\u001b[0m\n",
"\u001b[33mragproxyagent\u001b[0m (to assistant):\n",
"\n",
"You're a retrieve augmented coding assistant. You answer user's questions based on your own knowledge and the\n",
"context provided by the user.\n",
"If you can't answer the question with or without the current context, you should reply exactly `UPDATE CONTEXT`.\n",
"For code generation, you must obey the following rules:\n",
"Rule 1. You MUST NOT install any packages because all the packages needed are already installed.\n",
"Rule 2. You must follow the formats below to write your code:\n",
"```language\n",
"# your code\n",
"```\n",
"\n",
"User's question is: How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\n",
"\n",
"Context is: # Integrate - Spark\n",
"\n",
"FLAML has integrated Spark for distributed training. There are two main aspects of integration with Spark:\n",
"\n",
"- Use Spark ML estimators for AutoML.\n",
"- Use Spark to run training in parallel spark jobs.\n",
"\n",
"## Spark ML Estimators\n",
"\n",
"FLAML integrates estimators based on Spark ML models. These models are trained in parallel using Spark, so we called them Spark estimators. To use these models, you first need to organize your data in the required format.\n",
"\n",
"### Data\n",
"\n",
"For Spark estimators, AutoML only consumes Spark data. FLAML provides a convenient function `to_pandas_on_spark` in the `flaml.automl.spark.utils` module to convert your data into a pandas-on-spark (`pyspark.pandas`) dataframe/series, which Spark estimators require.\n",
"\n",
"This utility function takes data in the form of a `pandas.Dataframe` or `pyspark.sql.Dataframe` and converts it into a pandas-on-spark dataframe. It also takes `pandas.Series` or `pyspark.sql.Dataframe` and converts it into a [pandas-on-spark](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/index.html) series. If you pass in a `pyspark.pandas.Dataframe`, it will not make any changes.\n",
"\n",
"This function also accepts optional arguments `index_col` and `default_index_type`.\n",
"\n",
"- `index_col` is the column name to use as the index, default is None.\n",
"- `default_index_type` is the default index type, default is \"distributed-sequence\". More info about default index type could be found on Spark official [documentation](https://spark.apache.org/docs/latest/api/python/user_guide/pandas_on_spark/options.html#default-index-type)\n",
"\n",
"Here is an example code snippet for Spark Data:\n",
"\n",
"```python\n",
"import pandas as pd\n",
"from flaml.automl.spark.utils import to_pandas_on_spark\n",
"\n",
"# Creating a dictionary\n",
"data = {\n",
" \"Square_Feet\": [800, 1200, 1800, 1500, 850],\n",
" \"Age_Years\": [20, 15, 10, 7, 25],\n",
" \"Price\": [100000, 200000, 300000, 240000, 120000],\n",
"}\n",
"\n",
"# Creating a pandas DataFrame\n",
"dataframe = pd.DataFrame(data)\n",
"label = \"Price\"\n",
"\n",
"# Convert to pandas-on-spark dataframe\n",
"psdf = to_pandas_on_spark(dataframe)\n",
"```\n",
"\n",
"To use Spark ML models you need to format your data appropriately. Specifically, use [`VectorAssembler`](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.feature.VectorAssembler.html) to merge all feature columns into a single vector column.\n",
"\n",
"Here is an example of how to use it:\n",
"\n",
"```python\n",
"from pyspark.ml.feature import VectorAssembler\n",
"\n",
"columns = psdf.columns\n",
"feature_cols = [col for col in columns if col != label]\n",
"featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
"psdf = featurizer.transform(psdf.to_spark(index_col=\"index\"))[\"index\", \"features\"]\n",
"```\n",
"\n",
"Later in conducting the experiment, use your pandas-on-spark data like non-spark data and pass them using `X_train, y_train` or `dataframe, label`.\n",
"\n",
"### Estimators\n",
"\n",
"#### Model List\n",
"\n",
"- `lgbm_spark`: The class for fine-tuning Spark version LightGBM models, using [SynapseML](https://microsoft.github.io/SynapseML/docs/features/lightgbm/about/) API.\n",
"\n",
"#### Usage\n",
"\n",
"First, prepare your data in the required format as described in the previous section.\n",
"\n",
"By including the models you intend to try in the `estimators_list` argument to `flaml.automl`, FLAML will start trying configurations for these models. If your input is Spark data, FLAML will also use estimators with the `_spark` postfix by default, even if you haven't specified them.\n",
"\n",
"Here is an example code snippet using SparkML models in AutoML:\n",
"\n",
"```python\n",
"import flaml\n",
"\n",
"# prepare your data in pandas-on-spark format as we previously mentioned\n",
"\n",
"automl = flaml.AutoML()\n",
"settings = {\n",
" \"time_budget\": 30,\n",
" \"metric\": \"r2\",\n",
" \"estimator_list\": [\"lgbm_spark\"], # this setting is optional\n",
" \"task\": \"regression\",\n",
"}\n",
"\n",
"automl.fit(\n",
" dataframe=psdf,\n",
" label=label,\n",
" **settings,\n",
")\n",
"```\n",
"\n",
"[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/automl_bankrupt_synapseml.ipynb)\n",
"\n",
"## Parallel Spark Jobs\n",
"\n",
"You can activate Spark as the parallel backend during parallel tuning in both [AutoML](/docs/Use-Cases/Task-Oriented-AutoML#parallel-tuning) and [Hyperparameter Tuning](/docs/Use-Cases/Tune-User-Defined-Function#parallel-tuning), by setting the `use_spark` to `true`. FLAML will dispatch your job to the distributed Spark backend using [`joblib-spark`](https://github.com/joblib/joblib-spark).\n",
"\n",
"Please note that you should not set `use_spark` to `true` when applying AutoML and Tuning for Spark Data. This is because only SparkML models will be used for Spark Data in AutoML and Tuning. As SparkML models run in parallel, there is no need to distribute them with `use_spark` again.\n",
"\n",
"All the Spark-related arguments are stated below. These arguments are available in both Hyperparameter Tuning and AutoML:\n",
"\n",
"- `use_spark`: boolean, default=False | Whether to use spark to run the training in parallel spark jobs. This can be used to accelerate training on large models and large datasets, but will incur more overhead in time and thus slow down training in some cases. GPU training is not supported yet when use_spark is True. For Spark clusters, by default, we will launch one trial per executor. However, sometimes we want to launch more trials than the number of executors (e.g., local mode). In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`.\n",
"- `n_concurrent_trials`: int, default=1 | The number of concurrent trials. When n_concurrent_trials > 1, FLAML performes parallel tuning.\n",
"- `force_cancel`: boolean, default=False | Whether to forcely cancel Spark jobs if the search time exceeded the time budget. Spark jobs include parallel tuning jobs and Spark-based model training jobs.\n",
"\n",
"An example code snippet for using parallel Spark jobs:\n",
"\n",
"```python\n",
"import flaml\n",
"\n",
"automl_experiment = flaml.AutoML()\n",
"automl_settings = {\n",
" \"time_budget\": 30,\n",
" \"metric\": \"r2\",\n",
" \"task\": \"regression\",\n",
" \"n_concurrent_trials\": 2,\n",
" \"use_spark\": True,\n",
" \"force_cancel\": True, # Activating the force_cancel option can immediately halt Spark jobs once they exceed the allocated time_budget.\n",
"}\n",
"\n",
"automl.fit(\n",
" dataframe=dataframe,\n",
" label=label,\n",
" **automl_settings,\n",
")\n",
"```\n",
"\n",
"[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/integrate_spark.ipynb)\n",
"# Research\n",
"\n",
"For technical details, please check our research publications.\n",
"\n",
"- [FLAML: A Fast and Lightweight AutoML Library](https://www.microsoft.com/en-us/research/publication/flaml-a-fast-and-lightweight-automl-library/). Chi Wang, Qingyun Wu, Markus Weimer, Erkang Zhu. MLSys 2021.\n",
"\n",
"```bibtex\n",
"@inproceedings{wang2021flaml,\n",
" title={FLAML: A Fast and Lightweight AutoML Library},\n",
" author={Chi Wang and Qingyun Wu and Markus Weimer and Erkang Zhu},\n",
" year={2021},\n",
" booktitle={MLSys},\n",
"}\n",
"```\n",
"\n",
"- [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.\n",
"\n",
"```bibtex\n",
"@inproceedings{wu2021cfo,\n",
" title={Frugal Optimization for Cost-related Hyperparameters},\n",
" author={Qingyun Wu and Chi Wang and Silu Huang},\n",
" year={2021},\n",
" booktitle={AAAI},\n",
"}\n",
"```\n",
"\n",
"- [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.\n",
"\n",
"```bibtex\n",
"@inproceedings{wang2021blendsearch,\n",
" title={Economical Hyperparameter Optimization With Blended Search Strategy},\n",
" author={Chi Wang and Qingyun Wu and Silu Huang and Amin Saied},\n",
" year={2021},\n",
" booktitle={ICLR},\n",
"}\n",
"```\n",
"\n",
"- [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://aclanthology.org/2021.acl-long.178.pdf). Susan Xueqing Liu, Chi Wang. ACL 2021.\n",
"\n",
"```bibtex\n",
"@inproceedings{liuwang2021hpolm,\n",
" title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},\n",
" author={Susan Xueqing Liu and Chi Wang},\n",
" year={2021},\n",
" booktitle={ACL},\n",
"}\n",
"```\n",
"\n",
"- [ChaCha for Online AutoML](https://www.microsoft.com/en-us/research/publication/chacha-for-online-automl/). Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. ICML 2021.\n",
"\n",
"```bibtex\n",
"@inproceedings{wu2021chacha,\n",
" title={ChaCha for Online AutoML},\n",
" author={Qingyun Wu and Chi Wang and John Langford and Paul Mineiro and Marco Rossi},\n",
" year={2021},\n",
" booktitle={ICML},\n",
"}\n",
"```\n",
"\n",
"- [Fair AutoML](https://arxiv.org/abs/2111.06495). Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2111.06495 (2021).\n",
"\n",
"```bibtex\n",
"@inproceedings{wuwang2021fairautoml,\n",
" title={Fair AutoML},\n",
" author={Qingyun Wu and Chi Wang},\n",
" year={2021},\n",
" booktitle={ArXiv preprint arXiv:2111.06495},\n",
"}\n",
"```\n",
"\n",
"- [Mining Robust Default Configurations for Resource-constrained AutoML](https://arxiv.org/abs/2202.09927). Moe Kayali, Chi Wang. ArXiv preprint arXiv:2202.09927 (2022).\n",
"\n",
"```bibtex\n",
"@inproceedings{kayaliwang2022default,\n",
" title={Mining Robust Default Configurations for Resource-constrained AutoML},\n",
" author={Moe Kayali and Chi Wang},\n",
" year={2022},\n",
" booktitle={ArXiv preprint arXiv:2202.09927},\n",
"}\n",
"```\n",
"\n",
"- [Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives](https://openreview.net/forum?id=0Ij9_q567Ma). Shaokun Zhang, Feiran Jia, Chi Wang, Qingyun Wu. ICLR 2023 (notable-top-5%).\n",
"\n",
"```bibtex\n",
"@inproceedings{zhang2023targeted,\n",
" title={Targeted Hyperparameter Optimization with Lexicographic Preferences Over Multiple Objectives},\n",
" author={Shaokun Zhang and Feiran Jia and Chi Wang and Qingyun Wu},\n",
" booktitle={International Conference on Learning Representations},\n",
" year={2023},\n",
" url={https://openreview.net/forum?id=0Ij9_q567Ma},\n",
"}\n",
"```\n",
"\n",
"- [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).\n",
"\n",
"```bibtex\n",
"@inproceedings{wang2023EcoOptiGen,\n",
" title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},\n",
" author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},\n",
" year={2023},\n",
" booktitle={ArXiv preprint arXiv:2303.04673},\n",
"}\n",
"```\n",
"\n",
"- [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337). Yiran Wu, Feiran Jia, Shaokun Zhang, Hangyu Li, Erkang Zhu, Yue Wang, Yin Tat Lee, Richard Peng, Qingyun Wu, Chi Wang. ArXiv preprint arXiv:2306.01337 (2023).\n",
"\n",
"```bibtex\n",
"@inproceedings{wu2023empirical,\n",
" title={An Empirical Study on Challenging Math Problem Solving with GPT-4},\n",
" author={Yiran Wu and Feiran Jia and Shaokun Zhang and Hangyu Li and Erkang Zhu and Yue Wang and Yin Tat Lee and Richard Peng and Qingyun Wu and Chi Wang},\n",
" year={2023},\n",
" booktitle={ArXiv preprint arXiv:2306.01337},\n",
"}\n",
"```\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to ragproxyagent):\n",
"\n",
"To use FLAML to perform a classification task and use Spark for parallel training with a timeout of 30 seconds and force canceling jobs if the time limit is reached, you can follow the below code snippet:\n",
"\n",
"```python\n",
"import flaml\n",
"from flaml.automl.spark.utils import to_pandas_on_spark\n",
"from pyspark.ml.feature import VectorAssembler\n",
"\n",
"# Prepare your data in pandas-on-spark format\n",
"data = {\n",
" \"feature1\": [val1, val2, val3, val4],\n",
" \"feature2\": [val5, val6, val7, val8],\n",
" \"target\": [class1, class2, class1, class2],\n",
"}\n",
"\n",
"dataframe = pd.DataFrame(data)\n",
"label = \"target\"\n",
"psdf = to_pandas_on_spark(dataframe)\n",
"\n",
"# Prepare your features using VectorAssembler\n",
"columns = psdf.columns\n",
"feature_cols = [col for col in columns if col != label]\n",
"featurizer = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n",
"psdf = featurizer.transform(psdf)\n",
"\n",
"# Define AutoML settings and fit the model\n",
"automl = flaml.AutoML()\n",
"settings = {\n",
" \"time_budget\": 30,\n",
" \"metric\": \"accuracy\",\n",
" \"task\": \"classification\",\n",
" \"estimator_list\": [\"lgbm_spark\"], # Optional\n",
"}\n",
"\n",
"automl.fit(\n",
" dataframe=psdf,\n",
" label=label,\n",
" **settings,\n",
")\n",
"```\n",
"\n",
"In the code:\n",
"- Replace `val1, val2, ..., class1, class2` with your actual data values.\n",
"- Ensure the features and target columns are correctly specified in the data dictionary.\n",
"- Set the `time_budget` parameter to 30 to limit the training time.\n",
"- The `force_cancel` parameter is set to `True` to force cancel Spark jobs if the time limit is exceeded.\n",
"\n",
"Make sure to adapt the code to your specific dataset and requirements.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mragproxyagent\u001b[0m (to assistant):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33massistant\u001b[0m (to ragproxyagent):\n",
"\n",
"UPDATE CONTEXT\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[32mUpdating context and resetting conversation.\u001b[0m\n",
"VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n",
"VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n",
"VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n",
"VectorDB returns doc_ids: [['bdfbc921', '7968cf3c']]\n",
"\u001b[32mNo more context, will terminate.\u001b[0m\n",
"\u001b[33mragproxyagent\u001b[0m (to assistant):\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"# reset the assistant. Always reset the assistant before starting a new conversation.\n",
"assistant.reset()\n",
"\n",
"# given a problem, we use the ragproxyagent to generate a prompt to be sent to the assistant as the initial message.\n",
"# the assistant receives the message and generates a response. The response will be sent back to the ragproxyagent for processing.\n",
"# The conversation continues until the termination condition is met, in RetrieveChat, the termination condition when no human-in-loop is no code block detected.\n",
"# With human-in-loop, the conversation will continue until the user says \"exit\".\n",
"code_problem = \"How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\"\n",
"chat_result = ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=code_problem)"
]
}
],
"metadata": {
"front_matter": {
"description": "Explore the use of AutoGen's RetrieveChat for tasks like code generation from docstrings, answering complex questions with human feedback, and exploiting features like Update Context, custom prompts, and few-shot learning.",
"tags": [
"RAG"
]
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
},
"skip_test": "Requires interactive usage"
},
"nbformat": 4,
"nbformat_minor": 4
}

View File

@@ -72,6 +72,7 @@ extra_require = {
"mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
"retrievechat": retrieve_chat,
"retrievechat-pgvector": retrieve_chat_pgvector,
"retrievechat-mongodb": [*retrieve_chat, "pymongo>=4.0.0"],
"retrievechat-qdrant": [*retrieve_chat, "qdrant_client", "fastembed>=0.3.1"],
"autobuild": ["chromadb", "sentence-transformers", "huggingface-hub", "pysqlite3"],
"teachable": ["chromadb"],

View File

@@ -0,0 +1,402 @@
import logging
import os
import random
from time import monotonic, sleep
from typing import List
import pytest
from autogen.agentchat.contrib.vectordb.base import Document
try:
import pymongo
import sentence_transformers
from autogen.agentchat.contrib.vectordb.mongodb import MongoDBAtlasVectorDB
except ImportError:
# To display warning in pyproject.toml [tool.pytest.ini_options] set log_cli = true
logger = logging.getLogger(__name__)
logger.warning(f"skipping {__name__}. It requires one to pip install pymongo or the extra [retrievechat-mongodb]")
pytest.skip(allow_module_level=True)
from pymongo import MongoClient
from pymongo.collection import Collection
from pymongo.errors import OperationFailure
logger = logging.getLogger(__name__)
MONGODB_URI = os.environ.get("MONGODB_URI", "mongodb://localhost:27017/?directConnection=true")
MONGODB_DATABASE = os.environ.get("DATABASE", "autogen_test_db")
MONGODB_COLLECTION = os.environ.get("MONGODB_COLLECTION", "autogen_test_vectorstore")
MONGODB_INDEX = os.environ.get("MONGODB_INDEX", "vector_index")
RETRIES = 10
DELAY = 2
TIMEOUT = 120.0
def _wait_for_predicate(predicate, err, timeout=TIMEOUT, interval=DELAY):
"""Generic to block until the predicate returns true
Args:
predicate (Callable[[], bool]): A function that returns a boolean value
err (str): Error message to raise if the predicate never returns True
timeout (float, optional): Length of time to wait for predicate. Defaults to TIMEOUT.
interval (float, optional): Interval to check predicate. Defaults to DELAY.
Raises:
TimeoutError: If the predicate does not return True before the timeout elapses
"""
start = monotonic()
while not predicate():
if monotonic() - start > timeout:
raise TimeoutError(err)
sleep(interval)
def _delete_search_indexes(collection: Collection, wait=True):
"""Deletes all indexes in a collection
Args:
collection (pymongo.Collection): MongoDB Collection Abstraction
"""
for index in collection.list_search_indexes():
try:
collection.drop_search_index(index["name"])
except OperationFailure:
# Delete already issued
pass
if wait:
_wait_for_predicate(lambda: not list(collection.list_search_indexes()), "Not all search indexes deleted")
def _empty_collections_and_delete_indexes(database, collections=None, wait=True):
"""Empty all collections within the database and remove indexes
Args:
database (pymongo.Database): MongoDB Database Abstraction
"""
for collection_name in collections or database.list_collection_names():
_delete_search_indexes(database[collection_name], wait)
database[collection_name].drop()
@pytest.fixture
def db():
"""VectorDB setup and teardown, including collections and search indexes"""
database = MongoClient(MONGODB_URI)[MONGODB_DATABASE]
_empty_collections_and_delete_indexes(database)
vectorstore = MongoDBAtlasVectorDB(
connection_string=MONGODB_URI,
database_name=MONGODB_DATABASE,
wait_until_index_ready=TIMEOUT,
overwrite=True,
)
yield vectorstore
_empty_collections_and_delete_indexes(database)
@pytest.fixture
def example_documents() -> List[Document]:
"""Note mix of integers and strings as ids"""
return [
Document(id=1, content="Dogs are tough.", metadata={"a": 1}),
Document(id=2, content="Cats have fluff.", metadata={"b": 1}),
Document(id="1", content="What is a sandwich?", metadata={"c": 1}),
Document(id="2", content="A sandwich makes a great lunch.", metadata={"d": 1, "e": 2}),
]
@pytest.fixture
def db_with_indexed_clxn(collection_name):
"""VectorDB with a collection created immediately"""
database = MongoClient(MONGODB_URI)[MONGODB_DATABASE]
_empty_collections_and_delete_indexes(database, [collection_name], wait=True)
vectorstore = MongoDBAtlasVectorDB(
connection_string=MONGODB_URI,
database_name=MONGODB_DATABASE,
wait_until_index_ready=TIMEOUT,
collection_name=collection_name,
overwrite=True,
)
yield vectorstore, vectorstore.db[collection_name]
_empty_collections_and_delete_indexes(database, [collection_name])
_COLLECTION_NAMING_CACHE = []
@pytest.fixture
def collection_name():
collection_id = random.randint(0, 100)
while collection_id in _COLLECTION_NAMING_CACHE:
collection_id = random.randint(0, 100)
_COLLECTION_NAMING_CACHE.append(collection_id)
return f"{MONGODB_COLLECTION}_{collection_id}"
def test_create_collection(db, collection_name):
"""
def create_collection(collection_name: str,
overwrite: bool = False) -> Collection
Create a collection in the vector database.
- Case 1. if the collection does not exist, create the collection.
- Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
- Case 3. the collection exists and overwrite is False, return the existing collection.
- Case 4. the collection exists and overwrite is False and get_or_create is False, raise a ValueError
"""
collection_case_1 = db.create_collection(
collection_name=collection_name,
)
assert collection_case_1.name == collection_name
collection_case_2 = db.create_collection(
collection_name=collection_name,
overwrite=True,
)
assert collection_case_2.name == collection_name
collection_case_3 = db.create_collection(
collection_name=collection_name,
)
assert collection_case_3.name == collection_name
with pytest.raises(ValueError):
db.create_collection(collection_name=collection_name, overwrite=False, get_or_create=False)
def test_get_collection(db, collection_name):
with pytest.raises(ValueError):
db.get_collection()
collection_created = db.create_collection(collection_name)
assert isinstance(collection_created, Collection)
assert collection_created.name == collection_name
collection_got = db.get_collection(collection_name)
assert collection_got.name == collection_created.name
assert collection_got.name == db.active_collection.name
def test_delete_collection(db, collection_name):
assert collection_name not in db.list_collections()
collection = db.create_collection(collection_name)
assert collection_name in db.list_collections()
db.delete_collection(collection.name)
assert collection_name not in db.list_collections()
def test_insert_docs(db, collection_name, example_documents):
# Test that there's an active collection
with pytest.raises(ValueError) as exc:
db.insert_docs(example_documents)
assert "No collection is specified" in str(exc.value)
# Test upsert
db.insert_docs(example_documents, collection_name, upsert=True)
# Create a collection
db.delete_collection(collection_name)
collection = db.create_collection(collection_name)
# Insert example documents
db.insert_docs(example_documents, collection_name=collection_name)
found = list(collection.find({}))
assert len(found) == len(example_documents)
# Check that documents have correct fields, including "_id" and "embedding" but not "id"
assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found])
# Check ids
assert {doc["_id"] for doc in found} == {1, "1", 2, "2"}
# Check embedding lengths
assert len(found[0]["embedding"]) == 384
def test_update_docs(db_with_indexed_clxn, example_documents):
db, collection = db_with_indexed_clxn
# Use update_docs to insert new documents
db.update_docs(example_documents, collection.name, upsert=True)
# Test that no changes were made to example_documents
assert set(example_documents[0].keys()) == {"id", "content", "metadata"}
assert collection.count_documents({}) == len(example_documents)
found = list(collection.find({}))
# Check that documents have correct fields, including "_id" and "embedding" but not "id"
assert all([set(doc.keys()) == {"_id", "content", "metadata", "embedding"} for doc in found])
assert all([isinstance(doc["embedding"][0], float) for doc in found])
assert all([len(doc["embedding"]) == db.dimensions for doc in found])
# Check ids
assert {doc["_id"] for doc in found} == {1, "1", 2, "2"}
# Update an *existing* Document
updated_doc = Document(id=1, content="Cats are tough.", metadata={"a": 10})
db.update_docs([updated_doc], collection.name)
assert collection.find_one({"_id": 1})["content"] == "Cats are tough."
# Upsert a *new* Document
new_id = 3
new_doc = Document(id=new_id, content="Cats are tough.")
db.update_docs([new_doc], collection.name, upsert=True)
assert collection.find_one({"_id": new_id})["content"] == "Cats are tough."
# Attempting to use update to insert a new doc
# *without* setting upsert set to True
# is a no-op in MongoDB. # TODO Confirm behaviour and autogen's preference.
new_id = 4
new_doc = Document(id=new_id, content="That is NOT a sandwich?")
db.update_docs([new_doc], collection.name)
assert collection.find_one({"_id": new_id}) is None
def test_delete_docs(db_with_indexed_clxn, example_documents):
db, clxn = db_with_indexed_clxn
# Insert example documents
db.insert_docs(example_documents, collection_name=clxn.name)
# Delete the 1s
db.delete_docs(ids=[1, "1"], collection_name=clxn.name)
# Confirm just the 2s remain
assert {2, "2"} == {doc["_id"] for doc in clxn.find({})}
def test_get_docs_by_ids(db_with_indexed_clxn, example_documents):
db, clxn = db_with_indexed_clxn
# Insert example documents
db.insert_docs(example_documents, collection_name=clxn.name)
# Test without setting "include" kwarg
docs = db.get_docs_by_ids(ids=[2, "2"], collection_name=clxn.name)
assert len(docs) == 2
assert all([doc["id"] in [2, "2"] for doc in docs])
assert set(docs[0].keys()) == {"id", "content", "metadata"}
# Test with include
docs = db.get_docs_by_ids(ids=[2], include=["content"], collection_name=clxn.name)
assert len(docs) == 1
assert set(docs[0].keys()) == {"id", "content"}
# Test with empty ids list
docs = db.get_docs_by_ids(ids=[], include=["content"], collection_name=clxn.name)
assert len(docs) == 0
# Test with ids=None (returns all documents)
docs = db.get_docs_by_ids(ids=None, include=["content"], collection_name=clxn.name)
assert len(docs) == 4
def test_retrieve_docs_empty(db_with_indexed_clxn):
db, clxn = db_with_indexed_clxn
assert db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=2) == []
def test_retrieve_docs_populated_db_empty_query(db_with_indexed_clxn, example_documents):
db, clxn = db_with_indexed_clxn
db.insert_docs(example_documents, collection_name=clxn.name)
# Empty list of queries returns empty list of results
results = db.retrieve_docs(queries=[], collection_name=clxn.name, n_results=2)
assert results == []
def test_retrieve_docs(db_with_indexed_clxn, example_documents):
"""Begin testing Atlas Vector Search
NOTE: Indexing may take some time, so we must be patient on the first query.
We have the wait_until_index_ready flag to ensure the index is created and ready.
Immediately adding documents and then querying them is only standard practice in testing.
"""
db, clxn = db_with_indexed_clxn
# Insert example documents
db.insert_docs(example_documents, collection_name=clxn.name)
n_results = 2 # Number of closest docs to return
def results_ready():
results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results)
return len(results[0]) == n_results
_wait_for_predicate(results_ready, f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.")
results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results)
assert {doc[0]["id"] for doc in results[0]} == {1, 2}
assert all(["embedding" not in doc[0] for doc in results[0]])
def test_retrieve_docs_with_embedding(db_with_indexed_clxn, example_documents):
"""Begin testing Atlas Vector Search
NOTE: Indexing may take some time, so we must be patient on the first query.
We have the wait_until_index_ready flag to ensure the index is created and ready.
Immediately adding documents and then querying them is only standard practice in testing.
"""
db, clxn = db_with_indexed_clxn
# Insert example documents
db.insert_docs(example_documents, collection_name=clxn.name)
n_results = 2 # Number of closest docs to return
def results_ready():
results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results)
return len(results[0]) == n_results
_wait_for_predicate(results_ready, f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.")
results = db.retrieve_docs(queries=["Cats"], collection_name=clxn.name, n_results=n_results, include_embedding=True)
assert {doc[0]["id"] for doc in results[0]} == {1, 2}
assert all(["embedding" in doc[0] for doc in results[0]])
def test_retrieve_docs_multiple_queries(db_with_indexed_clxn, example_documents):
db, clxn = db_with_indexed_clxn
# Insert example documents
db.insert_docs(example_documents, collection_name=clxn.name)
n_results = 2 # Number of closest docs to return
queries = ["Some good pets", "What kind of Sandwich?"]
def results_ready():
results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results)
return all([len(res) == n_results for res in results])
_wait_for_predicate(results_ready, f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.")
results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=2)
assert len(results) == len(queries)
assert all([len(res) == n_results for res in results])
assert {doc[0]["id"] for doc in results[0]} == {1, 2}
assert {doc[0]["id"] for doc in results[1]} == {"1", "2"}
def test_retrieve_docs_with_threshold(db_with_indexed_clxn, example_documents):
db, clxn = db_with_indexed_clxn
# Insert example documents
db.insert_docs(example_documents, collection_name=clxn.name)
n_results = 2 # Number of closest docs to return
queries = ["Cats"]
def results_ready():
results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results)
return len(results[0]) == n_results
_wait_for_predicate(results_ready, f"Failed to retrieve docs after waiting {TIMEOUT} seconds after each.")
# Distance Threshold of .3 means that the score must be .7 or greater
# only one result should be that value
results = db.retrieve_docs(queries=queries, collection_name=clxn.name, n_results=n_results, distance_threshold=0.3)
assert len(results[0]) == 1
assert all([doc[1] >= 0.7 for doc in results[0]])
def test_wait_until_document_ready(collection_name, example_documents):
database = MongoClient(MONGODB_URI)[MONGODB_DATABASE]
_empty_collections_and_delete_indexes(database, [collection_name], wait=True)
try:
vectorstore = MongoDBAtlasVectorDB(
connection_string=MONGODB_URI,
database_name=MONGODB_DATABASE,
wait_until_index_ready=TIMEOUT,
collection_name=collection_name,
overwrite=True,
wait_until_document_ready=TIMEOUT,
)
vectorstore.insert_docs(example_documents)
assert vectorstore.retrieve_docs(queries=["Cats"], n_results=4)
finally:
_empty_collections_and_delete_indexes(database, [collection_name])
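The suite above targets the local Atlas container used in CI by default; MONGODB_URI, DATABASE, MONGODB_COLLECTION, and MONGODB_INDEX override the targets. A minimal sketch for running it locally (the path comes from this PR, the URI is the module's own default):

```python
# Minimal sketch: run the MongoDB vectordb tests against a local
# mongodb-atlas-local container; override MONGODB_URI for a real Atlas cluster.
import os
import pytest

os.environ.setdefault("MONGODB_URI", "mongodb://localhost:27017/?directConnection=true")
raise SystemExit(pytest.main(["-v", "test/agentchat/contrib/vectordb/test_mongodb.py"]))
```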