[Re-Opened] Support for PGVector Database in Autogen (#2439)

* PGVector Contrib Initial Commit - KnucklesTeam:autogen:pgvector_contrib fork

* Update website/docs/ecosystem/pgvector.md

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* Updated qdrant installation instructions.

* Fixed openai version.

* Added dependencies to install for qdrant and pgvector in contrib tests.

* Added dependencies to install for qdrant and pgvector in contrib tests.

* Cleaned up dependencies.

* Removed flaml out of setup.py. Used only for notebook example.

* Added PGVector notebook link

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Audel Rouhi 2024-04-18 13:00:03 -07:00 committed by GitHub
parent d5e30e09e8
commit ded2d612c3
12 changed files with 3912 additions and 11 deletions


@@ -41,11 +41,15 @@ jobs:
pip install -e .
python -c "import autogen"
pip install coverage pytest-asyncio
- name: Install PostgreSQL
run: |
sudo apt install postgresql -y
- name: Start PostgreSQL service
run: sudo service postgresql start
- name: Install packages for test when needed
run: |
pip install docker
pip install qdrant_client[fastembed]
pip install -e .[retrievechat]
pip install -e .[retrievechat-qdrant,retrievechat-pgvector]
- name: Coverage
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -53,7 +57,7 @@ jobs:
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
run: |
coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py::test_retrievechat test/agentchat/contrib/test_qdrant_retrievechat.py::test_retrievechat
coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py::test_retrievechat test/agentchat/contrib/test_qdrant_retrievechat.py::test_retrievechat test/agentchat/contrib/test_pgvector_retrievechat.py::test_retrievechat
coverage xml
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3


@@ -42,16 +42,20 @@ jobs:
- name: Install qdrant_client when python-version is 3.10
if: matrix.python-version == '3.10'
run: |
pip install qdrant_client[fastembed]
pip install .[retrievechat-qdrant]
- name: Install unstructured when python-version is 3.9 and on linux
if: matrix.python-version == '3.9' && matrix.os == 'ubuntu-latest'
run: |
sudo apt-get update
sudo apt-get install -y tesseract-ocr poppler-utils
pip install unstructured[all-docs]==0.13.0
- name: Install packages and dependencies for RetrieveChat
- name: Install and Start PostgreSQL
if: matrix.os == 'ubuntu-latest'
run: |
pip install -e .[retrievechat]
sudo apt install postgresql -y
sudo service postgresql start
- name: Install packages and dependencies for PGVector
run: |
pip install -e .[retrievechat-pgvector]
- name: Set AUTOGEN_USE_DOCKER based on OS
shell: bash
run: |


@@ -185,7 +185,7 @@ class VectorDBFactory:
Factory class for creating vector databases.
"""
PREDEFINED_VECTOR_DB = ["chroma"]
PREDEFINED_VECTOR_DB = ["chroma", "pgvector"]
@staticmethod
def create_vector_db(db_type: str, **kwargs) -> VectorDB:
@@ -203,6 +203,10 @@ class VectorDBFactory:
from .chromadb import ChromaVectorDB
return ChromaVectorDB(**kwargs)
if db_type.lower() in ["pgvector", "pgvectordb"]:
from .pgvectordb import PGVectorDB
return PGVectorDB(**kwargs)
else:
raise ValueError(
f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."


@@ -0,0 +1,736 @@
import os
import re
from typing import Callable, List
import numpy as np
from sentence_transformers import SentenceTransformer
from .base import Document, ItemID, QueryResults, VectorDB
from .utils import get_logger
try:
import pgvector
from pgvector.psycopg import register_vector
except ImportError:
raise ImportError("Please install pgvector: `pip install pgvector`")
try:
import psycopg
except ImportError:
raise ImportError("Please install psycopg: `pip install psycopg`")
PGVECTOR_MAX_BATCH_SIZE = os.environ.get("PGVECTOR_MAX_BATCH_SIZE", 40000)
logger = get_logger(__name__)
class Collection:
"""
A Collection object for PGVector.
Attributes:
client: The PGVector client.
collection_name (str): The name of the collection. Default is "documents".
embedding_function (Callable): The embedding function used to generate the vector representation.
metadata (Optional[dict]): The metadata of the collection.
get_or_create (Optional): The flag indicating whether to get or create the collection.
"""
def __init__(
self,
client=None,
collection_name: str = "autogen-docs",
embedding_function: Callable = None,
metadata=None,
get_or_create=None,
):
"""
Initialize the Collection object.
Args:
client: The PostgreSQL client.
collection_name: The name of the collection. Default is "documents".
embedding_function: The embedding function used to generate the vector representation.
metadata: The metadata of the collection.
get_or_create: The flag indicating whether to get or create the collection.
Returns:
None
"""
self.client = client
self.embedding_function = embedding_function
self.name = self.set_collection_name(collection_name)
self.require_embeddings_or_documents = False
self.ids = []
self.embedding_function = (
SentenceTransformer("all-MiniLM-L6-v2") if embedding_function is None else embedding_function
)
self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
self.documents = ""
self.get_or_create = get_or_create
def set_collection_name(self, collection_name):
name = re.sub("-", "_", collection_name)
self.name = name
return self.name
def add(self, ids: List[ItemID], embeddings: List, metadatas: List, documents: List):
"""
Add documents to the collection.
Args:
ids (List[ItemID]): A list of document IDs.
embeddings (List): A list of document embeddings.
metadatas (List): A list of document metadatas.
documents (List): A list of documents.
Returns:
None
"""
cursor = self.client.cursor()
if embeddings is None:
embeddings = [self.embedding_function.encode(doc) for doc in documents]
sql_values = []
for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
metadata = re.sub("'", '"', str(metadata))
sql_values.append((doc_id, embedding, metadata, document))
sql_string = f"INSERT INTO {self.name} (id, embedding, metadatas, documents) VALUES (%s, %s, %s, %s);"
cursor.executemany(sql_string, sql_values)
cursor.close()
def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
"""
Upsert documents into the collection.
Args:
ids (List[ItemID]): A list of document IDs.
documents (List): A list of documents.
embeddings (List): A list of document embeddings.
metadatas (List): A list of document metadatas.
Returns:
None
"""
cursor = self.client.cursor()
sql_values = []
if embeddings is not None and metadatas is not None:
for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
metadata = re.sub("'", '"', str(metadata))
sql_values.append((doc_id, embedding, metadata, document, embedding, metadata, document))
sql_string = (
f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n"
f"VALUES (%s, %s, %s, %s)\n"
f"ON CONFLICT (id)\n"
f"DO UPDATE SET embedding = %s,\n"
f"metadatas = %s, documents = %s;\n"
)
elif embeddings is not None:
for doc_id, embedding, document in zip(ids, embeddings, documents):
sql_values.append((doc_id, embedding, document, embedding, document))
sql_string = (
f"INSERT INTO {self.name} (id, embedding, documents) "
f"VALUES (%s, %s, %s) ON CONFLICT (id)\n"
f"DO UPDATE SET embedding = %s, documents = %s;\n"
)
elif metadatas is not None:
for doc_id, metadata, document in zip(ids, metadatas, documents):
metadata = re.sub("'", '"', str(metadata))
embedding = self.embedding_function.encode(document)
sql_values.append((doc_id, metadata, embedding, document, metadata, document, embedding))
sql_string = (
f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n"
f"VALUES (%s, %s, %s, %s)\n"
f"ON CONFLICT (id)\n"
f"DO UPDATE SET metadatas = %s, documents = %s, embedding = %s;\n"
)
else:
for doc_id, document in zip(ids, documents):
embedding = self.embedding_function.encode(document)
sql_values.append((doc_id, document, embedding, document))
sql_string = (
f"INSERT INTO {self.name} (id, documents, embedding)\n"
f"VALUES (%s, %s, %s)\n"
f"ON CONFLICT (id)\n"
f"DO UPDATE SET documents = %s;\n"
)
logger.debug(f"Upsert SQL String:\n{sql_string}\n{sql_values}")
cursor.executemany(sql_string, sql_values)
cursor.close()
def count(self):
"""
Get the total number of documents in the collection.
Returns:
int: The total number of documents.
"""
cursor = self.client.cursor()
query = f"SELECT COUNT(*) FROM {self.name}"
cursor.execute(query)
total = cursor.fetchone()[0]
cursor.close()
try:
total = int(total)
except (TypeError, ValueError):
total = None
return total
def get(self, ids=None, include=None, where=None, limit=None, offset=None):
"""
Retrieve documents from the collection.
Args:
ids (Optional[List]): A list of document IDs.
include (Optional): The fields to include.
where (Optional): Additional filtering criteria.
limit (Optional): The maximum number of documents to retrieve.
offset (Optional): The offset for pagination.
Returns:
List: The retrieved documents.
"""
cursor = self.client.cursor()
if include:
query = f'SELECT (id, {", ".join(map(str, include))}, embedding) FROM {self.name}'
else:
query = f"SELECT * FROM {self.name}"
if ids:
query = f"{query} WHERE id IN {ids}"
elif where:
query = f"{query} WHERE {where}"
if offset:
query = f"{query} OFFSET {offset}"
if limit:
query = f"{query} LIMIT {limit}"
retrieved_documents = []
try:
cursor.execute(query)
retrieval = cursor.fetchall()
for retrieved_document in retrieval:
retrieved_documents.append(
Document(
id=retrieved_document[0][0],
metadata=retrieved_document[0][1],
content=retrieved_document[0][2],
embedding=retrieved_document[0][3],
)
)
except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn):
logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead.")
self.create_collection(collection_name=self.name)
logger.info(f"Created table {self.name}")
cursor.close()
return retrieved_documents
def update(self, ids: List, embeddings: List, metadatas: List, documents: List):
"""
Update documents in the collection.
Args:
ids (List): A list of document IDs.
embeddings (List): A list of document embeddings.
metadatas (List): A list of document metadatas.
documents (List): A list of documents.
Returns:
None
"""
cursor = self.client.cursor()
sql_values = []
for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
sql_values.append((doc_id, embedding, metadata, document, doc_id, embedding, metadata, document))
sql_string = (
f"INSERT INTO {self.name} (id, embedding, metadatas, documents) "
f"VALUES (%s, %s, %s, %s) "
f"ON CONFLICT (id) "
f"DO UPDATE SET id = %s, embedding = %s, "
f"metadatas = %s, documents = %s;\n"
)
logger.debug(f"Update SQL String:\n{sql_string}\n")
cursor.executemany(sql_string, sql_values)
cursor.close()
@staticmethod
def euclidean_distance(arr1: List[float], arr2: List[float]) -> float:
"""
Calculate the Euclidean distance between two vectors.
Parameters:
- arr1 (List[float]): The first vector.
- arr2 (List[float]): The second vector.
Returns:
- float: The Euclidean distance between arr1 and arr2.
"""
dist = np.linalg.norm(arr1 - arr2)
return dist
@staticmethod
def cosine_distance(arr1: List[float], arr2: List[float]) -> float:
"""
Calculate the cosine distance between two vectors.
Parameters:
- arr1 (List[float]): The first vector.
- arr2 (List[float]): The second vector.
Returns:
- float: The cosine distance between arr1 and arr2.
"""
dist = 1 - np.dot(arr1, arr2) / (np.linalg.norm(arr1) * np.linalg.norm(arr2))
return dist
@staticmethod
def inner_product_distance(arr1: List[float], arr2: List[float]) -> float:
"""
Calculate the negative inner product between two vectors, matching pgvector's `<#>` operator.
Parameters:
- arr1 (List[float]): The first vector.
- arr2 (List[float]): The second vector.
Returns:
- float: The negative inner product of arr1 and arr2.
"""
dist = -np.dot(arr1, arr2)
return dist
def query(
self,
query_texts: List[str],
collection_name: str = None,
n_results: int = 10,
distance_type: str = "euclidean",
distance_threshold: float = -1,
) -> QueryResults:
"""
Query documents in the collection.
Args:
query_texts (List[str]): A list of query texts.
collection_name (Optional[str]): The name of the collection.
n_results (int): The maximum number of results to return.
distance_type (Optional[str]): Distance search type - euclidean or cosine
distance_threshold (Optional[float]): Distance threshold to limit searches
Returns:
QueryResults: The query results.
"""
if collection_name:
self.name = collection_name
if distance_threshold == -1:
distance_threshold = ""
elif distance_threshold > 0:
distance_threshold = f"< {distance_threshold}"
cursor = self.client.cursor()
results = []
for query in query_texts:
vector = self.embedding_function.encode(query, convert_to_tensor=False).tolist()
if distance_type.lower() == "cosine":
index_function = "<=>"
elif distance_type.lower() == "euclidean":
index_function = "<->"
elif distance_type.lower() == "inner-product":
index_function = "<#>"
else:
index_function = "<->"
query = (
f"SELECT id, documents, embedding, metadatas FROM {self.name}\n"
f"ORDER BY embedding {index_function} '{str(vector)}'::vector {distance_threshold}\n"
f"LIMIT {n_results}"
)
cursor.execute(query)
query_results = []
for row in cursor.fetchall():
fetched_document = Document(id=row[0], content=row[1], embedding=row[2], metadata=row[3])
fetched_document_array = self.convert_string_to_array(array_string=fetched_document.get("embedding"))
if distance_type.lower() == "cosine":
distance = self.cosine_distance(fetched_document_array, vector)
elif distance_type.lower() == "euclidean":
distance = self.euclidean_distance(fetched_document_array, vector)
elif distance_type.lower() == "inner-product":
distance = self.inner_product_distance(fetched_document_array, vector)
else:
distance = self.euclidean_distance(fetched_document_array, vector)
query_results.append((fetched_document, distance))
results.append(query_results)
cursor.close()
logger.debug(f"Query Results: {results}")
return results
@staticmethod
def convert_string_to_array(array_string) -> List[float]:
"""
Convert a string representation of an array to a list of floats.
Parameters:
- array_string (str): The string representation of the array.
Returns:
- list: A list of floats parsed from the input string. If the input is
not a string, it returns the input itself.
"""
if not isinstance(array_string, str):
return array_string
array_string = array_string.strip("[]")
array = [float(num) for num in array_string.split()]
return array
def modify(self, metadata, collection_name: str = None):
"""
Modify metadata for the collection.
Args:
collection_name: The name of the collection.
metadata: The new metadata.
Returns:
None
"""
if collection_name:
self.name = collection_name
cursor = self.client.cursor()
cursor.execute(
"UPDATE collections SET metadata = %s WHERE collection_name = %s;", (metadata, self.name)
)
cursor.close()
def delete(self, ids: List[ItemID], collection_name: str = None):
"""
Delete documents from the collection.
Args:
ids (List[ItemID]): A list of document IDs to delete.
collection_name (str): The name of the collection to delete.
Returns:
None
"""
if collection_name:
self.name = collection_name
cursor = self.client.cursor()
id_placeholders = ", ".join(["%s"] * len(ids))
cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids)
cursor.close()
def delete_collection(self, collection_name: str = None):
"""
Delete the entire collection.
Args:
collection_name (Optional[str]): The name of the collection to delete.
Returns:
None
"""
if collection_name:
self.name = collection_name
cursor = self.client.cursor()
cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
cursor.close()
def create_collection(self, collection_name: str = None):
"""
Create a new collection.
Args:
collection_name (Optional[str]): The name of the new collection.
Returns:
None
"""
if collection_name:
self.name = collection_name
cursor = self.client.cursor()
cursor.execute(
f"CREATE TABLE {self.name} ("
f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector(384));"
f"CREATE INDEX "
f'ON {self.name} USING hnsw (embedding vector_l2_ops) WITH (m = {self.metadata["hnsw:M"]}, '
f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
f"CREATE INDEX "
f'ON {self.name} USING hnsw (embedding vector_cosine_ops) WITH (m = {self.metadata["hnsw:M"]}, '
f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
f"CREATE INDEX "
f'ON {self.name} USING hnsw (embedding vector_ip_ops) WITH (m = {self.metadata["hnsw:M"]}, '
f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
)
cursor.close()
class PGVectorDB(VectorDB):
"""
A vector database that uses PGVector as the backend.
"""
def __init__(
self,
*,
connection_string: str = None,
host: str = None,
port: int = None,
dbname: str = None,
connect_timeout: int = 10,
embedding_function: Callable = None,
metadata: dict = None,
) -> None:
"""
Initialize the vector database.
Note: connection_string or host + port + dbname must be specified
Args:
connection_string: "postgresql://username:password@hostname:port/database" | The PGVector connection string. Default is None.
host: str | The host to connect to. Default is None.
port: int | The port to connect to. Default is None.
dbname: str | The database name to connect to. Default is None.
connect_timeout: int | The timeout to set for the connection. Default is 10.
embedding_function: Callable | The embedding function used to generate the vector representation
of the documents. Default is None.
metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 16}. Creates Index on table
using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef".
For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
kwargs: dict | Additional keyword arguments.
Returns:
None
"""
if connection_string:
self.client = psycopg.connect(conninfo=connection_string, autocommit=True)
elif host and port and dbname:
self.client = psycopg.connect(
host=host, port=port, dbname=dbname, connect_timeout=connect_timeout, autocommit=True
)
self.embedding_function = (
SentenceTransformer("all-MiniLM-L6-v2") if embedding_function is None else embedding_function
)
self.metadata = metadata
self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
register_vector(self.client)
self.active_collection = None
def create_collection(
self, collection_name: str, overwrite: bool = False, get_or_create: bool = True
) -> Collection:
"""
Create a collection in the vector database.
Case 1. if the collection does not exist, create the collection.
Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection,
otherwise it raise a ValueError.
Args:
collection_name: str | The name of the collection.
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
get_or_create: bool | Whether to get the collection if it exists. Default is True.
Returns:
Collection | The collection object.
"""
try:
if self.active_collection and self.active_collection.name == collection_name:
collection = self.active_collection
else:
collection = self.get_collection(collection_name)
except ValueError:
collection = None
if collection is None:
collection = Collection(
collection_name=collection_name,
embedding_function=self.embedding_function,
get_or_create=get_or_create,
metadata=self.metadata,
)
collection.set_collection_name(collection_name=collection_name)
collection.create_collection(collection_name=collection_name)
return collection
elif overwrite:
self.delete_collection(collection_name)
collection = Collection(
collection_name=collection_name,
embedding_function=self.embedding_function,
get_or_create=get_or_create,
metadata=self.metadata,
)
collection.set_collection_name(collection_name=collection_name)
collection.create_collection(collection_name=collection_name)
return collection
elif get_or_create:
return collection
else:
raise ValueError(f"Collection {collection_name} already exists.")
def get_collection(self, collection_name: str = None) -> Collection:
"""
Get the collection from the vector database.
Args:
collection_name: str | The name of the collection. Default is None. If None, return the
current active collection.
Returns:
Collection | The collection object.
"""
if collection_name is None:
if self.active_collection is None:
raise ValueError("No collection is specified.")
else:
logger.debug(
f"No collection is specified. Using current active collection {self.active_collection.name}."
)
else:
self.active_collection = Collection(
client=self.client, collection_name=collection_name, embedding_function=self.embedding_function
)
return self.active_collection
def delete_collection(self, collection_name: str) -> None:
"""
Delete the collection from the vector database.
Args:
collection_name: str | The name of the collection.
Returns:
None
"""
collection = self.get_collection(collection_name)
collection.delete_collection(collection_name)
if self.active_collection and self.active_collection.name == collection_name:
self.active_collection = None
def _batch_insert(
self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
):
batch_size = int(PGVECTOR_MAX_BATCH_SIZE)
default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
for i in range(0, len(documents), min(batch_size, len(documents))):
end_idx = i + min(batch_size, len(documents) - i)
collection_kwargs = {
"documents": documents[i:end_idx],
"ids": ids[i:end_idx],
"metadatas": metadatas[i:end_idx] if metadatas else [default_metadata] * (end_idx - i),
"embeddings": embeddings[i:end_idx] if embeddings else None,
}
if upsert:
collection.upsert(**collection_kwargs)
else:
collection.add(**collection_kwargs)
def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
"""
Insert documents into the collection of the vector database.
Args:
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
collection_name: str | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
kwargs: Dict | Additional keyword arguments.
Returns:
None
"""
if not docs:
return
if docs[0].get("content") is None:
raise ValueError("The document content is required.")
if docs[0].get("id") is None:
raise ValueError("The document id is required.")
documents = [doc.get("content") for doc in docs]
ids = [doc.get("id") for doc in docs]
collection = self.get_collection(collection_name)
if docs[0].get("embedding") is None:
logger.debug(
"No content embedding is provided. "
"Will use the VectorDB's embedding function to generate the content embedding."
)
embeddings = None
else:
embeddings = [doc.get("embedding") for doc in docs]
if docs[0].get("metadata") is None:
metadatas = None
else:
metadatas = [doc.get("metadata") for doc in docs]
self._batch_insert(collection, embeddings, ids, metadatas, documents, upsert)
def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
"""
Update documents in the collection of the vector database.
Args:
docs: List[Document] | A list of documents.
collection_name: str | The name of the collection. Default is None.
Returns:
None
"""
self.insert_docs(docs, collection_name, upsert=True)
def delete_docs(self, ids: List[ItemID], collection_name: str = None) -> None:
"""
Delete documents from the collection of the vector database.
Args:
ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
collection_name: str | The name of the collection. Default is None.
kwargs: Dict | Additional keyword arguments.
Returns:
None
"""
collection = self.get_collection(collection_name)
collection.delete(ids=ids, collection_name=collection_name)
def retrieve_docs(
self,
queries: List[str],
collection_name: str = None,
n_results: int = 10,
distance_threshold: float = -1,
) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.
Args:
queries: List[str] | A list of queries. Each query is a string.
collection_name: str | The name of the collection. Default is None.
n_results: int | The number of relevant documents to return. Default is 10.
distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
returned. Don't filter with it if < 0. Default is -1.
kwargs: Dict | Additional keyword arguments.
Returns:
QueryResults | The query results. Each query result is a list of list of tuples containing the document and
the distance.
"""
collection = self.get_collection(collection_name)
if isinstance(queries, str):
queries = [queries]
results = collection.query(
query_texts=queries,
n_results=n_results,
distance_threshold=distance_threshold,
)
logger.debug(f"Retrieve Docs Results:\n{results}")
return results
def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> List[Document]:
"""
Retrieve documents from the collection of the vector database based on the ids.
Args:
ids: List[ItemID] | A list of document ids.
collection_name: str | The name of the collection. Default is None.
include: List[str] | The fields to include. Default is None.
If None, will include ["metadatas", "documents"], ids will always be included.
kwargs: dict | Additional keyword arguments.
Returns:
List[Document] | The results.
"""
collection = self.get_collection(collection_name)
include = include if include else ["metadatas", "documents"]
results = collection.get(ids, include=include, **kwargs)
logger.debug(f"Retrieve Documents by ID Results:\n{results}")
return results

File diff suppressed because one or more lines are too long


@@ -21,7 +21,7 @@
"Some extra dependencies are needed for this notebook, which can be installed via pip:\n",
"\n",
"```bash\n",
"pip install \"pyautogen[retrievechat]>=0.2.3\" \"flaml[automl]\" \"qdrant_client[fastembed]\"\n",
"pip install \"pyautogen[retrievechat-qdrant]\" \"flaml[automl]\"\n",
"```\n",
"\n",
"For more information, please refer to the [installation guide](/docs/installation/).\n",
@@ -196,7 +196,7 @@
}
],
"source": [
"%pip install \"pyautogen[retrievechat]>=0.2.3\" \"flaml[automl]\" \"qdrant_client[fastembed]\""
"%pip install \"pyautogen[retrievechat-qdrant]\" \"flaml[automl]\""
]
},
{


@@ -60,6 +60,23 @@ setuptools.setup(
"blendsearch": ["flaml[blendsearch]"],
"mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
"retrievechat": ["chromadb", "sentence_transformers", "pypdf", "ipython", "beautifulsoup4", "markdownify"],
"retrievechat-pgvector": [
"pgvector>=0.2.5",
"psycopg>=3.1.18",
"sentence_transformers",
"pypdf",
"ipython",
"beautifulsoup4",
"markdownify",
],
"retrievechat-qdrant": [
"qdrant_client[fastembed]",
"sentence_transformers",
"pypdf",
"ipython",
"beautifulsoup4",
"markdownify",
],
"autobuild": ["chromadb", "sentence-transformers", "huggingface-hub"],
"teachable": ["chromadb"],
"lmm": ["replicate", "pillow"],


@@ -0,0 +1,130 @@
#!/usr/bin/env python3 -m pytest
import os
import sys
import pytest
from sentence_transformers import SentenceTransformer
from autogen import config_list_from_json
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from conftest import skip_openai # noqa: E402
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
try:
import pgvector
from autogen.agentchat.contrib.retrieve_assistant_agent import (
RetrieveAssistantAgent,
)
from autogen.agentchat.contrib.retrieve_user_proxy_agent import (
RetrieveUserProxyAgent,
)
except ImportError:
skip = True
else:
skip = False or skip_openai
test_dir = os.path.join(os.path.dirname(__file__), "../..", "test_files")
@pytest.mark.skipif(
skip,
reason="dependency is not installed OR requested to skip",
)
def test_retrievechat():
conversations = {}
# ChatCompletion.start_logging(conversations) # deprecated in v0.2
config_list = config_list_from_json(
OAI_CONFIG_LIST,
file_location=KEY_LOC,
)
assistant = RetrieveAssistantAgent(
name="assistant",
system_message="You are a helpful assistant.",
llm_config={
"timeout": 600,
"seed": 42,
"config_list": config_list,
},
)
sentence_transformer_ef = SentenceTransformer("all-MiniLM-L6-v2")
ragproxyagent = RetrieveUserProxyAgent(
name="ragproxyagent",
human_input_mode="NEVER",
max_consecutive_auto_reply=3,
retrieve_config={
"task": "code",
"docs_path": [
"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Examples/Integrate%20-%20Spark.md",
"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Research.md",
"https://raw.githubusercontent.com/Knuckles-Team/geniusbot/main/README.md",
"https://raw.githubusercontent.com/Knuckles-Team/repository-manager/main/README.md",
"https://raw.githubusercontent.com/Knuckles-Team/gitlab-api/main/README.md",
"https://raw.githubusercontent.com/Knuckles-Team/media-downloader/main/README.md",
os.path.join(os.path.abspath(""), "..", "website", "docs"),
],
"custom_text_types": ["non-existent-type"],
"chunk_token_size": 2000,
"model": config_list[0]["model"],
"vector_db": "pgvector", # PGVector database
"collection_name": "test_collection",
"db_config": {
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
},
"embedding_function": sentence_transformer_ef,
"get_or_create": True, # set to False if you don't want to reuse an existing collection
"overwrite": False, # set to True if you want to overwrite an existing collection
},
code_execution_config=False, # set to False if you don't want to execute the code
)
assistant.reset()
code_problem = "How can I use FLAML to perform a classification task, set use_spark=True, train 30 seconds and force cancel jobs if time limit is reached."
ragproxyagent.initiate_chat(
assistant, message=ragproxyagent.message_generator, problem=code_problem, search_string="spark", silent=True
)
print(conversations)
@pytest.mark.skipif(
skip,
reason="dependency is not installed",
)
def test_retrieve_config(caplog):
# test warning message when no docs_path is provided
ragproxyagent = RetrieveUserProxyAgent(
name="ragproxyagent",
human_input_mode="NEVER",
max_consecutive_auto_reply=2,
retrieve_config={
"chunk_token_size": 2000,
"get_or_create": True,
},
)
# Capture the printed content
captured_logs = caplog.records[0]
print(captured_logs)
# Assert on the printed content
assert (
f"docs_path is not provided in retrieve_config. Will raise ValueError if the collection `{ragproxyagent._collection_name}` doesn't exist."
in captured_logs.message
)
assert captured_logs.levelname == "WARNING"
if __name__ == "__main__":
test_retrievechat()
test_retrieve_config()


@@ -0,0 +1,82 @@
import os
import sys
import pytest
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
try:
import pgvector
import sentence_transformers
from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB
except ImportError:
skip = True
else:
skip = False
@pytest.mark.skipif(skip, reason="dependency is not installed OR requested to skip")
def test_pgvector():
# test create collection
db_config = {
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
}
db = PGVectorDB(connection_string=db_config["connection_string"])
collection_name = "test_collection"
collection = db.create_collection(collection_name, overwrite=True, get_or_create=True)
assert collection.name == collection_name
# test_delete_collection
db.delete_collection(collection_name)
pytest.raises(ValueError, db.get_collection, collection_name)
# test more create collection
collection = db.create_collection(collection_name, overwrite=False, get_or_create=False)
assert collection.name == collection_name
pytest.raises(ValueError, db.create_collection, collection_name, overwrite=False, get_or_create=False)
collection = db.create_collection(collection_name, overwrite=True, get_or_create=False)
assert collection.name == collection_name
collection = db.create_collection(collection_name, overwrite=False, get_or_create=True)
assert collection.name == collection_name
# test_get_collection
collection = db.get_collection(collection_name)
assert collection.name == collection_name
# test_insert_docs
docs = [{"content": "doc1", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
db.insert_docs(docs, collection_name, upsert=False)
res = db.get_collection(collection_name).get(["1", "2"])
assert res["documents"] == ["doc1", "doc2"]
# test_update_docs
docs = [{"content": "doc11", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
db.update_docs(docs, collection_name)
res = db.get_collection(collection_name).get(["1", "2"])
assert res["documents"] == ["doc11", "doc2"]
# test_delete_docs
ids = ["1"]
collection_name = "test_collection"
db.delete_docs(ids, collection_name)
res = db.get_collection(collection_name).get(ids)
assert res["documents"] == []
# test_retrieve_docs
queries = ["doc2", "doc3"]
collection_name = "test_collection"
res = db.retrieve_docs(queries, collection_name)
assert [[r[0]["id"] for r in rr] for rr in res] == [["2", "3"], ["3", "2"]]
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1)
print(res)
assert [[r[0]["id"] for r in rr] for rr in res] == [["2"], ["3"]]
# test_get_docs_by_ids
res = db.get_docs_by_ids(["1", "2"], collection_name)
assert [r["id"] for r in res] == ["2"] # "1" has been deleted
if __name__ == "__main__":
test_pgvector()
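The `distance_threshold` behavior exercised above can be illustrated with a toy, dependency-free sketch. The cosine-distance metric and the "negative threshold disables filtering" convention are assumptions that mirror what the test expects, not the actual PGVector implementation:

```python
import math

def cosine_distance(a, b):
    # 1 - cosine similarity; 0.0 means the vectors point in the same direction.
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)

def toy_retrieve(query_vec, docs, distance_threshold=-1):
    # docs: list of (id, vector). Sort by distance, then drop results beyond
    # the threshold; a negative threshold means "no filtering".
    scored = sorted(
        ((doc_id, cosine_distance(query_vec, vec)) for doc_id, vec in docs),
        key=lambda pair: pair[1],
    )
    if distance_threshold >= 0:
        scored = [(doc_id, d) for doc_id, d in scored if d <= distance_threshold]
    return scored
```

With a tight threshold only near-identical documents survive, which is why the test above expects fewer ids when `distance_threshold=0.1` is passed.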


@@ -0,0 +1,5 @@
# PGVector
[PGVector](https://github.com/pgvector/pgvector) is an open-source vector similarity search extension for Postgres.
- [PGVector + AutoGen Code Examples](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_pgvector_RetrieveChat.ipynb)
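pgvector runs the similarity search inside Postgres (its `<->` operator is Euclidean/L2 distance); the pure-Python sketch below only illustrates the idea, and is not part of the pgvector API:

```python
import math

def l2_distance(a, b):
    # Equivalent of pgvector's `<->` operator: Euclidean (L2) distance.
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def nearest(query, rows, k=1):
    # rows: list of (id, vector); return the ids of the k vectors closest to query.
    return [row_id for row_id, _ in sorted(rows, key=lambda r: l2_distance(query, r[1]))[:k]]
```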


@@ -44,12 +44,22 @@ Example notebooks:
## retrievechat
`pyautogen` supports retrieval-augmented generation tasks such as question answering and code generation with RAG agents. Please install with the [retrievechat] option to use it with ChromaDB.
```bash
pip install "pyautogen[retrievechat]"
```
Alternatively, `pyautogen` also supports PGVector and Qdrant, which can be installed in place of ChromaDB or alongside it.
```bash
pip install "pyautogen[retrievechat-pgvector]"
```
```bash
pip install "pyautogen[retrievechat-qdrant]"
```
RetrieveChat can handle various types of documents. By default, it can process
plain text and PDF files, including formats such as 'txt', 'json', 'csv', 'tsv',
'md', 'html', 'htm', 'rtf', 'rst', 'jsonl', 'log', 'xml', 'yaml', 'yml' and 'pdf'.
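The default extension list above can be expressed as a small filter; the helper below is purely illustrative and not part of the AutoGen API:

```python
from pathlib import Path

# Default text/PDF extensions RetrieveChat can process (from the list above).
SUPPORTED_EXTENSIONS = {
    "txt", "json", "csv", "tsv", "md", "html", "htm", "rtf",
    "rst", "jsonl", "log", "xml", "yaml", "yml", "pdf",
}

def is_supported(path):
    """Illustrative helper: True if the file's extension is processable by default."""
    return Path(path).suffix.lstrip(".").lower() in SUPPORTED_EXTENSIONS
```

Extra formats (such as `mdx` in the example below) can be enabled through the `custom_text_types` entry of `retrieve_config`.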


@@ -53,12 +53,73 @@ ragproxyagent.initiate_chat(
) # search_string is used as an extra filter for the embeddings search, in this case, we only want to search documents that contain "spark".
```
## Example Setup: RAG with Retrieval Augmented Agents using PGVector
The following is an example setup demonstrating how to create retrieval augmented agents in AutoGen with PGVector as the vector database:
### Step 1. Create an instance of `RetrieveAssistantAgent` and `RetrieveUserProxyAgent`.
Here, the `RetrieveUserProxyAgent` instance acts as a proxy agent that retrieves relevant information based on the user's input.
Specify either the `connection_string`, or the individual `host`, `port`, `database`, `username`, and `password` fields in the `db_config`.
```python
assistant = RetrieveAssistantAgent(
name="assistant",
system_message="You are a helpful assistant.",
llm_config={
"timeout": 600,
"cache_seed": 42,
"config_list": config_list,
},
)
ragproxyagent = RetrieveUserProxyAgent(
name="ragproxyagent",
human_input_mode="NEVER",
max_consecutive_auto_reply=3,
retrieve_config={
"task": "code",
"docs_path": [
"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Examples/Integrate%20-%20Spark.md",
"https://raw.githubusercontent.com/microsoft/FLAML/main/website/docs/Research.md",
os.path.join(os.path.abspath(""), "..", "website", "docs"),
],
"vector_db": "pgvector",
"collection_name": "autogen_docs",
"db_config": {
"connection_string": "postgresql://testuser:testpwd@localhost:5432/vectordb", # Optional - connect to an external vector database
# "host": None, # Optional vector database host
# "port": None, # Optional vector database port
# "database": None, # Optional vector database name
# "username": None, # Optional vector database username
# "password": None, # Optional vector database password
},
"custom_text_types": ["mdx"],
"chunk_token_size": 2000,
"model": config_list[0]["model"],
"get_or_create": True,
},
code_execution_config=False,
)
```
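If you prefer the field-by-field form mentioned above, here is a sketch of an equivalent `db_config` (the credentials are the same placeholders used in the connection string, and a real setup should supply its own):

```python
# Equivalent field-by-field configuration (values are placeholders);
# use either this form or "connection_string", not both.
db_config = {
    "host": "localhost",
    "port": 5432,
    "database": "vectordb",
    "username": "testuser",
    "password": "testpwd",
}

# The connection string shown earlier is just these fields assembled:
connection_string = (
    f"postgresql://{db_config['username']}:{db_config['password']}"
    f"@{db_config['host']}:{db_config['port']}/{db_config['database']}"
)
```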
### Step 2. Initiating Agent Chat with Retrieval Augmentation
Once the retrieval augmented agents are set up, you can initiate a chat with retrieval augmentation using the following code:
```python
code_problem = "How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached."
ragproxyagent.initiate_chat(
assistant, message=ragproxyagent.message_generator, problem=code_problem, search_string="spark"
) # search_string is used as an extra filter for the embeddings search, in this case, we only want to search documents that contain "spark".
```
## Online Demo
[Retrieval-Augmented Chat Demo on Huggingface](https://huggingface.co/spaces/thinkall/autogen-demos)
## More Examples and Notebooks
For more detailed examples and notebooks showcasing the usage of retrieval augmented agents in AutoGen, refer to the following:
- Automated Code Generation and Question Answering with Retrieval Augmented Agents - [View Notebook](/docs/notebooks/agentchat_RetrieveChat)
- Automated Code Generation and Question Answering with [PGVector](https://github.com/pgvector/pgvector) based Retrieval Augmented Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_pgvector_RetrieveChat.ipynb)
- Automated Code Generation and Question Answering with [Qdrant](https://qdrant.tech/) based Retrieval Augmented Agents - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_qdrant_RetrieveChat.ipynb)
- Chat with OpenAI Assistant with Retrieval Augmentation - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_oai_assistant_retrieval.ipynb)
- **RAG**: Group Chat with Retrieval Augmented Generation (with 5 group member agents and 1 manager agent) - [View Notebook](/docs/notebooks/agentchat_groupchat_RAG)