2447 fix pgvector tests and notebook (#2455)

* Re-added missing notebook

* Test installing postgres

* Error handle the connection.

* Fixed import.

* Fixed import.

* Fixed creation of collection without client.

* PGVector portion working. OpenAI untested.

* Fixed prints.

* Added output.

* Fixed pre-commits.

* Run pgvector notebook

* Improve efficiency of get_collection

* Fix delete_collection

* Fixed issues with pytests and validated functions.

* Validated pytests.

* Fixed pre-commits

* Separated extras_require to allow more logic. retrievechat base dependencies are now included in the pgvector and qdrant extras.

* Fixed extra newline.

* Added username and password fields.

* URL Encode the connection string parameters to support symbols like %

* Fixed pre-commits.

* Added pgvector service

* pgvector doesn't have health intervals.

* Switched to colon based key values.

* Run on Ubuntu only. Linux is the only option with container service support.

* Using default credentials instead.

* Fix postgres setup

* Fix postgres setup

* Don't skip tests on win and mac

* Fix command error

* Try apt install postgresql

* Assert table does not exist when deleted.

* Raise ValueError on an empty list or None value provided for IDs

* pre-commit

* Add install pgvector

* Add install pgvector

* Reorg test files, create a separate job for test pgvector

* Fix format

* Fix env format

* Simplify job name, enable test_retrieve_config

* Fix test_retrieve_config

* Corrected behavior for get_docs_by_ids with no ids returning all docs.

* Corrected behavior for get_docs_by_ids with no ids returning all docs.

* Fixed pre-commits.

* Added return values for all functions.

* Validated distance search is implemented correctly.

* Validated all pytests

* Removed print.

* Added default clause.

* Make ids optional

* Fix test, make it more robust

* Bump version of openai for the vector_store support

* Added support for choosing the sentence transformer model.

* Added error handling for model name entered.

* Updated model info.

* Added model_name db_config param.

* pre-commit fixes and last link fix.

* Use secrets password.

* fix: link fixed

* updated tests

* Updated config_list.

* pre-commit fix.

* Added chat_result to all output.
Unable to re-run notebooks.

* Pre-commit fix detected this requirement.

* Fix python 3.8 and 3.9 not supported for macos

* Fix python 3.8 and 3.9 not supported for macos

* Fix format

* Reran notebook with MetaLlama3Instruct7BQ4_k_M

* added gpt model.

* Reran notebook

---------

Co-authored-by: Li Jiang <bnujli@gmail.com>
Co-authored-by: Hk669 <hrushi669@gmail.com>
Audel Rouhi, 2024-04-28 06:43:02 -07:00, committed by GitHub
parent 600bd3f2fe
commit 1b8d65df0a
11 changed files with 1014 additions and 2198 deletions


@@ -24,6 +24,21 @@ jobs:
         python-version: ["3.10"]
     runs-on: ${{ matrix.os }}
     environment: openai1
+    services:
+      pgvector:
+        image: ankane/pgvector
+        env:
+          POSTGRES_DB: postgres
+          POSTGRES_USER: postgres
+          POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
+          POSTGRES_HOST_AUTH_METHOD: trust
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          - 5432:5432
     steps:
       # checkout to pr branch
       - name: Checkout
@@ -41,15 +56,10 @@ jobs:
           pip install -e .
           python -c "import autogen"
          pip install coverage pytest-asyncio
-      - name: Install PostgreSQL
-        run: |
-          sudo apt install postgresql -y
-      - name: Start PostgreSQL service
-        run: sudo service postgresql start
       - name: Install packages for test when needed
         run: |
           pip install docker
-          pip install -e .[retrievechat-qdrant,retrievechat-pgvector]
+          pip install -e .[retrievechat,retrievechat-qdrant,retrievechat-pgvector]
       - name: Coverage
         env:
           OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
@@ -57,13 +67,14 @@ jobs:
           AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
           OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
         run: |
-          coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py::test_retrievechat test/agentchat/contrib/test_qdrant_retrievechat.py::test_retrievechat test/agentchat/contrib/test_pgvector_retrievechat.py::test_retrievechat
+          coverage run -a -m pytest -k test_retrievechat test/agentchat/contrib/retrievechat
           coverage xml
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v3
         with:
           file: ./coverage.xml
           flags: unittests
   CompressionTest:
     strategy:
       matrix:


@@ -27,7 +27,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [ubuntu-latest, macos-latest, windows-2019]
+        os: [macos-latest, windows-2019]
         python-version: ["3.9", "3.10", "3.11"]
         exclude:
           - os: macos-latest
@@ -45,20 +45,10 @@ jobs:
       - name: Install qdrant_client when python-version is 3.10
         if: matrix.python-version == '3.10'
         run: |
-          pip install .[retrievechat-qdrant]
-      - name: Install unstructured when python-version is 3.9 and on linux
-        run: |
-          sudo apt-get update
-          sudo apt-get install -y tesseract-ocr poppler-utils
-          pip install unstructured[all-docs]==0.13.0
-      - name: Install and Start PostgreSQL
-        runs-on: ubuntu-latest
-        run: |
-          sudo apt install postgresql -y
-          sudo service postgresql start
-      - name: Install packages and dependencies for PGVector
-        run: |
-          pip install -e .[retrievechat-pgvector]
+          pip install -e .[retrievechat-qdrant]
+      - name: Install packages and dependencies for RetrieveChat
+        run: |
+          pip install -e .[retrievechat]
       - name: Set AUTOGEN_USE_DOCKER based on OS
         shell: bash
         run: |
@@ -68,7 +58,69 @@ jobs:
       - name: Coverage
         run: |
           pip install coverage>=5.3
-          coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb --skip-openai
+          coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/retrievechat test/agentchat/contrib/vectordb --skip-openai
+          coverage xml
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v3
+        with:
+          file: ./coverage.xml
+          flags: unittests
+  RetrieveChatTest-Ubuntu:
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.9", "3.10", "3.11"]
+    services:
+      pgvector:
+        image: ankane/pgvector
+        env:
+          POSTGRES_DB: postgres
+          POSTGRES_USER: postgres
+          POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
+          POSTGRES_HOST_AUTH_METHOD: trust
+        options: >-
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          - 5432:5432
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install packages and dependencies for all tests
+        run: |
+          python -m pip install --upgrade pip wheel
+          pip install pytest
+      - name: Install qdrant_client when python-version is 3.10
+        if: matrix.python-version == '3.10'
+        run: |
+          pip install -e .[retrievechat-qdrant]
+      - name: Install pgvector when on linux
+        run: |
+          pip install -e .[retrievechat-pgvector]
+      - name: Install unstructured when python-version is 3.9 and on linux
+        if: matrix.python-version == '3.9'
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y tesseract-ocr poppler-utils
+          pip install unstructured[all-docs]==0.13.0
+      - name: Install packages and dependencies for RetrieveChat
+        run: |
+          pip install -e .[retrievechat]
+      - name: Set AUTOGEN_USE_DOCKER based on OS
+        shell: bash
+        run: |
+          echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
+      - name: Coverage
+        run: |
+          pip install coverage>=5.3
+          coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/retrievechat test/agentchat/contrib/vectordb --skip-openai
           coverage xml
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v3


@ -1,5 +1,6 @@
import os import os
import re import re
import urllib.parse
from typing import Callable, List from typing import Callable, List
import numpy as np import numpy as np
@ -33,7 +34,8 @@ class Collection:
embedding_function (Callable): The embedding function used to generate the vector representation. embedding_function (Callable): The embedding function used to generate the vector representation.
metadata (Optional[dict]): The metadata of the collection. metadata (Optional[dict]): The metadata of the collection.
get_or_create (Optional): The flag indicating whether to get or create the collection. get_or_create (Optional): The flag indicating whether to get or create the collection.
model_name: (Optional str) | Sentence embedding model to use. Models can be chosen from:
https://huggingface.co/models?library=sentence-transformers
""" """
def __init__( def __init__(
@ -43,6 +45,7 @@ class Collection:
embedding_function: Callable = None, embedding_function: Callable = None,
metadata=None, metadata=None,
get_or_create=None, get_or_create=None,
model_name="all-MiniLM-L6-v2",
): ):
""" """
Initialize the Collection object. Initialize the Collection object.
@ -53,46 +56,76 @@ class Collection:
embedding_function: The embedding function used to generate the vector representation. embedding_function: The embedding function used to generate the vector representation.
metadata: The metadata of the collection. metadata: The metadata of the collection.
get_or_create: The flag indicating whether to get or create the collection. get_or_create: The flag indicating whether to get or create the collection.
model_name: | Sentence embedding model to use. Models can be chosen from:
https://huggingface.co/models?library=sentence-transformers
Returns: Returns:
None None
""" """
self.client = client self.client = client
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.model_name = model_name
self.name = self.set_collection_name(collection_name) self.name = self.set_collection_name(collection_name)
self.require_embeddings_or_documents = False self.require_embeddings_or_documents = False
self.ids = [] self.ids = []
self.embedding_function = ( try:
SentenceTransformer("all-MiniLM-L6-v2") if embedding_function is None else embedding_function self.embedding_function = (
) SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
)
except Exception as e:
logger.error(
f"Validate the model name entered: {self.model_name} "
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
)
raise e
self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16} self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
self.documents = "" self.documents = ""
self.get_or_create = get_or_create self.get_or_create = get_or_create
def set_collection_name(self, collection_name): def set_collection_name(self, collection_name) -> str:
name = re.sub("-", "_", collection_name) name = re.sub("-", "_", collection_name)
self.name = name self.name = name
return self.name return self.name
def add(self, ids: List[ItemID], embeddings: List, metadatas: List, documents: List): def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
""" """
Add documents to the collection. Add documents to the collection.
Args: Args:
ids (List[ItemID]): A list of document IDs. ids (List[ItemID]): A list of document IDs.
embeddings (List): A list of document embeddings. embeddings (List): A list of document embeddings. Optional
metadatas (List): A list of document metadatas. metadatas (List): A list of document metadatas. Optional
documents (List): A list of documents. documents (List): A list of documents.
Returns: Returns:
None None
""" """
cursor = self.client.cursor() cursor = self.client.cursor()
sql_values = [] sql_values = []
for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents): if embeddings is not None and metadatas is not None:
sql_values.append((doc_id, embedding, metadata, document)) for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
sql_string = f"INSERT INTO {self.name} (id, embedding, metadata, document) " f"VALUES (%s, %s, %s, %s);" metadata = re.sub("'", '"', str(metadata))
sql_values.append((doc_id, embedding, metadata, document))
sql_string = (
f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
)
elif embeddings is not None:
for doc_id, embedding, document in zip(ids, embeddings, documents):
sql_values.append((doc_id, embedding, document))
sql_string = f"INSERT INTO {self.name} (id, embedding, documents) " f"VALUES (%s, %s, %s);\n"
elif metadatas is not None:
for doc_id, metadata, document in zip(ids, metadatas, documents):
metadata = re.sub("'", '"', str(metadata))
embedding = self.embedding_function.encode(document)
sql_values.append((doc_id, metadata, embedding, document))
sql_string = (
f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
)
else:
for doc_id, document in zip(ids, documents):
embedding = self.embedding_function.encode(document)
sql_values.append((doc_id, document, embedding))
sql_string = f"INSERT INTO {self.name} (id, documents, embedding)\n" f"VALUES (%s, %s, %s);\n"
logger.debug(f"Add SQL String:\n{sql_string}\n{sql_values}")
cursor.executemany(sql_string, sql_values) cursor.executemany(sql_string, sql_values)
cursor.close() cursor.close()
@ -155,7 +188,7 @@ class Collection:
cursor.executemany(sql_string, sql_values) cursor.executemany(sql_string, sql_values)
cursor.close() cursor.close()
def count(self): def count(self) -> int:
""" """
Get the total number of documents in the collection. Get the total number of documents in the collection.
@ -173,7 +206,32 @@ class Collection:
total = None total = None
return total return total
def get(self, ids=None, include=None, where=None, limit=None, offset=None): def table_exists(self, table_name: str) -> bool:
"""
Check if a table exists in the PostgreSQL database.
Args:
table_name (str): The name of the table to check.
Returns:
bool: True if the table exists, False otherwise.
"""
cursor = self.client.cursor()
cursor.execute(
"""
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_name = %s
)
""",
(table_name,),
)
exists = cursor.fetchone()[0]
return exists
def get(self, ids=None, include=None, where=None, limit=None, offset=None) -> List[Document]:
""" """
Retrieve documents from the collection. Retrieve documents from the collection.
@ -188,39 +246,65 @@ class Collection:
List: The retrieved documents. List: The retrieved documents.
""" """
cursor = self.client.cursor() cursor = self.client.cursor()
# Initialize variables for query components
select_clause = "SELECT id, metadatas, documents, embedding"
from_clause = f"FROM {self.name}"
where_clause = ""
limit_clause = ""
offset_clause = ""
# Handle include clause
if include: if include:
query = f'SELECT (id, {", ".join(map(str, include))}, embedding) FROM {self.name}' select_clause = f"SELECT id, {', '.join(include)}, embedding"
else:
query = f"SELECT * FROM {self.name}" # Handle where clause
if ids: if ids:
query = f"{query} WHERE id IN {ids}" where_clause = f"WHERE id IN ({', '.join(['%s' for _ in ids])})"
elif where: elif where:
query = f"{query} WHERE {where}" where_clause = f"WHERE {where}"
if offset:
query = f"{query} OFFSET {offset}" # Handle limit and offset clauses
if limit: if limit:
query = f"{query} LIMIT {limit}" limit_clause = "LIMIT %s"
retreived_documents = [] if offset:
offset_clause = "OFFSET %s"
# Construct the full query
query = f"{select_clause} {from_clause} {where_clause} {limit_clause} {offset_clause}"
retrieved_documents = []
try: try:
cursor.execute(query) # Execute the query with the appropriate values
if ids is not None:
cursor.execute(query, ids)
else:
query_params = []
if limit:
query_params.append(limit)
if offset:
query_params.append(offset)
cursor.execute(query, query_params)
retrieval = cursor.fetchall() retrieval = cursor.fetchall()
for retrieved_document in retrieval: for retrieved_document in retrieval:
retreived_documents.append( retrieved_documents.append(
Document( Document(
id=retrieved_document[0][0], id=retrieved_document[0].strip(),
metadata=retrieved_document[0][1], metadata=retrieved_document[1],
content=retrieved_document[0][2], content=retrieved_document[2],
embedding=retrieved_document[0][3], embedding=retrieved_document[3],
) )
) )
except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn): except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn) as e:
logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead.") logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead. Error: {e}")
self.create_collection(collection_name=self.name) self.create_collection(collection_name=self.name)
logger.info(f"Created table {self.name}") logger.info(f"Created table {self.name}")
cursor.close()
return retreived_documents
def update(self, ids: List, embeddings: List, metadatas: List, documents: List): cursor.close()
return retrieved_documents
def update(self, ids: List, embeddings: List, metadatas: List, documents: List) -> None:
""" """
Update documents in the collection. Update documents in the collection.
@ -300,6 +384,7 @@ class Collection:
n_results: int = 10, n_results: int = 10,
distance_type: str = "euclidean", distance_type: str = "euclidean",
distance_threshold: float = -1, distance_threshold: float = -1,
include_embedding: bool = False,
) -> QueryResults: ) -> QueryResults:
""" """
Query documents in the collection. Query documents in the collection.
@ -310,21 +395,25 @@ class Collection:
n_results (int): The maximum number of results to return. n_results (int): The maximum number of results to return.
distance_type (Optional[str]): Distance search type - euclidean or cosine distance_type (Optional[str]): Distance search type - euclidean or cosine
distance_threshold (Optional[float]): Distance threshold to limit searches distance_threshold (Optional[float]): Distance threshold to limit searches
include_embedding (Optional[bool]): Include embedding values in QueryResults
Returns: Returns:
QueryResults: The query results. QueryResults: The query results.
""" """
if collection_name: if collection_name:
self.name = collection_name self.name = collection_name
clause = "ORDER BY"
if distance_threshold == -1: if distance_threshold == -1:
distance_threshold = "" distance_threshold = ""
clause = "ORDER BY"
elif distance_threshold > 0: elif distance_threshold > 0:
distance_threshold = f"< {distance_threshold}" distance_threshold = f"< {distance_threshold}"
clause = "WHERE"
cursor = self.client.cursor() cursor = self.client.cursor()
results = [] results = []
for query in query_texts: for query_text in query_texts:
vector = self.embedding_function.encode(query, convert_to_tensor=False).tolist() vector = self.embedding_function.encode(query_text, convert_to_tensor=False).tolist()
if distance_type.lower() == "cosine": if distance_type.lower() == "cosine":
index_function = "<=>" index_function = "<=>"
elif distance_type.lower() == "euclidean": elif distance_type.lower() == "euclidean":
@ -333,15 +422,16 @@ class Collection:
index_function = "<#>" index_function = "<#>"
else: else:
index_function = "<->" index_function = "<->"
query = ( query = (
f"SELECT id, documents, embedding, metadatas FROM {self.name}\n" f"SELECT id, documents, embedding, metadatas "
f"ORDER BY embedding {index_function} '{str(vector)}'::vector {distance_threshold}\n" f"FROM {self.name} "
f"{clause} embedding {index_function} '{str(vector)}' {distance_threshold} "
f"LIMIT {n_results}" f"LIMIT {n_results}"
) )
cursor.execute(query) cursor.execute(query)
result = []
for row in cursor.fetchall(): for row in cursor.fetchall():
fetched_document = Document(id=row[0], content=row[1], embedding=row[2], metadata=row[3]) fetched_document = Document(id=row[0].strip(), content=row[1], embedding=row[2], metadata=row[3])
fetched_document_array = self.convert_string_to_array(array_string=fetched_document.get("embedding")) fetched_document_array = self.convert_string_to_array(array_string=fetched_document.get("embedding"))
if distance_type.lower() == "cosine": if distance_type.lower() == "cosine":
distance = self.cosine_distance(fetched_document_array, vector) distance = self.cosine_distance(fetched_document_array, vector)
@ -351,9 +441,11 @@ class Collection:
distance = self.inner_product_distance(fetched_document_array, vector) distance = self.inner_product_distance(fetched_document_array, vector)
else: else:
distance = self.euclidean_distance(fetched_document_array, vector) distance = self.euclidean_distance(fetched_document_array, vector)
results.append((fetched_document, distance)) if not include_embedding:
fetched_document = Document(id=row[0].strip(), content=row[1], metadata=row[3])
result.append((fetched_document, distance))
results.append(result)
cursor.close() cursor.close()
results = [results]
logger.debug(f"Query Results: {results}") logger.debug(f"Query Results: {results}")
return results return results
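
For reference, the pgvector operators selected in the query above map onto these distance formulas. A minimal standalone numpy sketch, illustrative only; the module's cosine_distance/euclidean_distance/inner_product_distance helpers are assumed to compute the same quantities:

    import numpy as np

    def euclidean_distance(a, b):
        # pgvector "<->": L2 distance between the two vectors.
        return float(np.linalg.norm(np.array(a) - np.array(b)))

    def cosine_distance(a, b):
        # pgvector "<=>": cosine distance, i.e. 1 - cosine similarity.
        a, b = np.array(a), np.array(b)
        return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    def inner_product_distance(a, b):
        # pgvector "<#>": the negative inner product (more similar -> more negative).
        return float(-np.dot(np.array(a), np.array(b)))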
@@ -375,7 +467,7 @@ class Collection:
         array = [float(num) for num in array_string.split()]
         return array
 
-    def modify(self, metadata, collection_name: str = None):
+    def modify(self, metadata, collection_name: str = None) -> None:
         """
         Modify metadata for the collection.
 
@@ -394,7 +486,7 @@ class Collection:
         )
         cursor.close()
 
-    def delete(self, ids: List[ItemID], collection_name: str = None):
+    def delete(self, ids: List[ItemID], collection_name: str = None) -> None:
         """
         Delete documents from the collection.
 
@@ -408,10 +500,11 @@ class Collection:
         if collection_name:
             self.name = collection_name
         cursor = self.client.cursor()
-        cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({ids});")
+        id_placeholders = ", ".join(["%s" for _ in ids])
+        cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids)
         cursor.close()
 
-    def delete_collection(self, collection_name: str = None):
+    def delete_collection(self, collection_name: str = None) -> None:
         """
         Delete the entire collection.
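
The parameterized DELETE above builds one %s placeholder per id so that psycopg binds the values rather than formatting them into the SQL string. A standalone sketch of the same pattern, with an illustrative table name and connection string:

    import psycopg

    def delete_ids(conninfo, table, ids):
        # One placeholder per id; the driver handles quoting and escaping.
        placeholders = ", ".join(["%s" for _ in ids])
        with psycopg.connect(conninfo, autocommit=True) as client:
            with client.cursor() as cursor:
                cursor.execute(f"DELETE FROM {table} WHERE id IN ({placeholders});", ids)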
@@ -427,7 +520,7 @@ class Collection:
         cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
         cursor.close()
 
-    def create_collection(self, collection_name: str = None):
+    def create_collection(self, collection_name: str = None) -> None:
         """
         Create a new collection.
 
@@ -468,9 +561,12 @@ class PGVectorDB(VectorDB):
         host: str = None,
         port: int = None,
         dbname: str = None,
+        username: str = None,
+        password: str = None,
         connect_timeout: int = 10,
         embedding_function: Callable = None,
         metadata: dict = None,
+        model_name: str = "all-MiniLM-L6-v2",
     ) -> None:
         """
         Initialize the vector database.
@@ -482,6 +578,8 @@ class PGVectorDB(VectorDB):
             host: str | The host to connect to. Default is None.
             port: int | The port to connect to. Default is None.
             dbname: str | The database name to connect to. Default is None.
+            username: str | The database username to use. Default is None.
+            password: str | The database user password to use. Default is None.
             connect_timeout: int | The timeout to set for the connection. Default is 10.
             embedding_function: Callable | The embedding function used to generate the vector representation
                 of the documents. Default is None.
@@ -489,20 +587,48 @@ class PGVectorDB(VectorDB):
                 setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 16}. Creates Index on table
                 using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef".
                 For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
-            kwargs: dict | Additional keyword arguments.
+            model_name: str | Sentence embedding model to use. Models can be chosen from:
+                https://huggingface.co/models?library=sentence-transformers
 
         Returns:
             None
         """
-        if connection_string:
-            self.client = psycopg.connect(conninfo=connection_string, autocommit=True)
-        elif host and port and dbname:
-            self.client = psycopg.connect(
-                host=host, port=port, dbname=dbname, connect_timeout=connect_timeout, autocommit=True
-            )
-        self.embedding_function = (
-            SentenceTransformer("all-MiniLM-L6-v2") if embedding_function is None else embedding_function
-        )
+        try:
+            if connection_string:
+                parsed_connection = urllib.parse.urlparse(connection_string)
+                encoded_username = urllib.parse.quote(parsed_connection.username, safe="")
+                encoded_password = urllib.parse.quote(parsed_connection.password, safe="")
+                encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="")
+                encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="")
+                connection_string_encoded = (
+                    f"{parsed_connection.scheme}://{encoded_username}:{encoded_password}"
+                    f"@{encoded_host}:{parsed_connection.port}/{encoded_database}"
+                )
+                self.client = psycopg.connect(conninfo=connection_string_encoded, autocommit=True)
+            elif host and port and dbname:
+                self.client = psycopg.connect(
+                    host=host,
+                    port=port,
+                    dbname=dbname,
+                    username=username,
+                    password=password,
+                    connect_timeout=connect_timeout,
+                    autocommit=True,
+                )
+        except psycopg.Error as e:
+            logger.error("Error connecting to the database: ", e)
+            raise e
+        self.model_name = model_name
+        try:
+            self.embedding_function = (
+                SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
+            )
+        except Exception as e:
+            logger.error(
+                f"Validate the model name entered: {self.model_name} "
+                f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
+            )
+            raise e
         self.metadata = metadata
         self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
         register_vector(self.client)
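
The encoding step above percent-escapes each parsed component of the connection string, which is what lets symbols such as % or @ in a password survive the round trip. A quick illustration with a made-up credential:

    import urllib.parse

    password = "p@ss%word"
    print(urllib.parse.quote(password, safe=""))  # -> p%40ss%25word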
@@ -535,10 +661,12 @@ class PGVectorDB(VectorDB):
             collection = None
         if collection is None:
             collection = Collection(
+                client=self.client,
                 collection_name=collection_name,
                 embedding_function=self.embedding_function,
                 get_or_create=get_or_create,
                 metadata=self.metadata,
+                model_name=self.model_name,
             )
             collection.set_collection_name(collection_name=collection_name)
             collection.create_collection(collection_name=collection_name)
@@ -546,16 +674,30 @@ class PGVectorDB(VectorDB):
         elif overwrite:
             self.delete_collection(collection_name)
             collection = Collection(
+                client=self.client,
                 collection_name=collection_name,
                 embedding_function=self.embedding_function,
                 get_or_create=get_or_create,
                 metadata=self.metadata,
+                model_name=self.model_name,
             )
             collection.set_collection_name(collection_name=collection_name)
             collection.create_collection(collection_name=collection_name)
             return collection
         elif get_or_create:
             return collection
+        elif not collection.table_exists(table_name=collection_name):
+            collection = Collection(
+                client=self.client,
+                collection_name=collection_name,
+                embedding_function=self.embedding_function,
+                get_or_create=get_or_create,
+                metadata=self.metadata,
+                model_name=self.model_name,
+            )
+            collection.set_collection_name(collection_name=collection_name)
+            collection.create_collection(collection_name=collection_name)
+            return collection
         else:
             raise ValueError(f"Collection {collection_name} already exists.")
@@ -578,9 +720,13 @@ class PGVectorDB(VectorDB):
                 f"No collection is specified. Using current active collection {self.active_collection.name}."
             )
         else:
-            self.active_collection = Collection(
-                client=self.client, collection_name=collection_name, embedding_function=self.embedding_function
-            )
+            if not (self.active_collection and self.active_collection.name == collection_name):
+                self.active_collection = Collection(
+                    client=self.client,
+                    collection_name=collection_name,
+                    embedding_function=self.embedding_function,
+                    model_name=self.model_name,
+                )
         return self.active_collection
 
     def delete_collection(self, collection_name: str) -> None:
@@ -593,16 +739,20 @@ class PGVectorDB(VectorDB):
         Returns:
             None
         """
-        self.active_collection.delete_collection(collection_name)
+        if self.active_collection:
+            self.active_collection.delete_collection(collection_name)
+        else:
+            collection = self.get_collection(collection_name)
+            collection.delete_collection(collection_name)
         if self.active_collection and self.active_collection.name == collection_name:
             self.active_collection = None
 
     def _batch_insert(
         self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
-    ):
+    ) -> None:
         batch_size = int(PGVECTOR_MAX_BATCH_SIZE)
         default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
-        default_metadatas = [default_metadata]
+        default_metadatas = [default_metadata] * min(batch_size, len(documents))
         for i in range(0, len(documents), min(batch_size, len(documents))):
             end_idx = i + min(batch_size, len(documents) - i)
             collection_kwargs = {
@@ -715,12 +865,14 @@ class PGVectorDB(VectorDB):
         logger.debug(f"Retrieve Docs Results:\n{results}")
         return results
 
-    def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> List[Document]:
+    def get_docs_by_ids(
+        self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
+    ) -> List[Document]:
         """
         Retrieve documents from the collection of the vector database based on the ids.
 
         Args:
-            ids: List[ItemID] | A list of document ids.
+            ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
             collection_name: str | The name of the collection. Default is None.
             include: List[str] | The fields to include. Default is None.
                 If None, will include ["metadatas", "documents"], ids will always be included.
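
Putting the pieces together, a minimal usage sketch of the updated class, mirroring the test file below; the connection string matches the CI service defaults and the collection name is illustrative:

    from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB

    db = PGVectorDB(
        connection_string="postgresql://postgres:postgres@localhost:5432/postgres",
        model_name="all-MiniLM-L6-v2",  # the default sentence-transformers model
    )
    db.create_collection(collection_name="demo", overwrite=True, get_or_create=True)
    db.insert_docs([{"content": "doc1", "id": "1"}], collection_name="demo")
    print(db.get_docs_by_ids(collection_name="demo"))  # no ids -> returns all documents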


@@ -48,14 +48,14 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 1,
+    "execution_count": 4,
     "metadata": {},
     "outputs": [
      {
       "name": "stdout",
       "output_type": "stream",
       "text": [
-       "models to use: ['gpt-35-turbo', 'gpt-35-turbo-0613']\n"
+       "models to use: ['gpt-3.5-turbo-0125']\n"
       ]
      }
     ],
@@ -73,7 +73,9 @@
     "# a vector database instance\n",
     "from autogen.retrieve_utils import TEXT_FORMATS\n",
     "\n",
-    "config_list = autogen.config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n",
+    "config_list = [\n",
+    "    {\"model\": \"gpt-3.5-turbo-0125\", \"api_key\": \"<YOUR_API_KEY>\", \"api_type\": \"openai\"},\n",
+    "]\n",
     "\n",
     "assert len(config_list) > 0\n",
     "print(\"models to use: \", [config_list[i][\"model\"] for i in range(len(config_list))])"
@@ -105,7 +107,7 @@
      "output_type": "stream",
      "text": [
       "Accepted file formats for `docs_path`:\n",
-      "['ppt', 'jsonl', 'csv', 'yaml', 'rst', 'htm', 'pdf', 'tsv', 'doc', 'docx', 'pptx', 'msg', 'yml', 'xml', 'md', 'json', 'txt', 'epub', 'org', 'xlsx', 'log', 'html', 'odt', 'rtf']\n"
+      "['odt', 'xml', 'pdf', 'docx', 'html', 'md', 'htm', 'csv', 'rst', 'org', 'ppt', 'doc', 'log', 'json', 'epub', 'jsonl', 'pptx', 'yml', 'xlsx', 'tsv', 'txt', 'yaml', 'msg', 'rtf']\n"
      ]
     }
    ],
@@ -118,16 +120,7 @@
    "cell_type": "code",
    "execution_count": 3,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
-      "  torch.utils._pytree._register_pytree_node(\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "# 1. create an RetrieveAssistantAgent instance named \"assistant\"\n",
     "assistant = RetrieveAssistantAgent(\n",
@@ -500,7 +493,7 @@
     "# The conversation continues until the termination condition is met, in RetrieveChat, the termination condition when no human-in-loop is no code block detected.\n",
     "# With human-in-loop, the conversation will continue until the user says \"exit\".\n",
     "code_problem = \"How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\"\n",
-    "ragproxyagent.initiate_chat(\n",
+    "chat_result = ragproxyagent.initiate_chat(\n",
     "    assistant, message=ragproxyagent.message_generator, problem=code_problem, search_string=\"spark\"\n",
     ")  # search_string is used as an extra filter for the embeddings search, in this case, we only want to search documents that contain \"spark\"."
    ]
@@ -822,7 +815,7 @@
     "assistant.reset()\n",
     "\n",
     "qa_problem = \"Who is the author of FLAML?\"\n",
-    "ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=qa_problem)"
+    "chat_result = ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=qa_problem)"
    ]
   },
   {
@@ -1235,7 +1228,7 @@
     "# set `human_input_mode` to be `ALWAYS`, so the agent will ask for human input at every step.\n",
     "ragproxyagent.human_input_mode = \"ALWAYS\"\n",
     "code_problem = \"how to build a time series forecasting model for stock price using FLAML?\"\n",
-    "ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=code_problem)"
+    "chat_result = ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=code_problem)"
    ]
   },
   {
@@ -1793,7 +1786,7 @@
     "# set `human_input_mode` to be `ALWAYS`, so the agent will ask for human input at every step.\n",
     "ragproxyagent.human_input_mode = \"ALWAYS\"\n",
     "qa_problem = \"Is there a function named `tune_automl` in FLAML?\"\n",
-    "ragproxyagent.initiate_chat(\n",
+    "chat_result = ragproxyagent.initiate_chat(\n",
     "    assistant, message=ragproxyagent.message_generator, problem=qa_problem\n",
     ")  # type \"exit\" to exit the conversation"
    ]
@@ -2386,7 +2379,9 @@
     "    assistant.reset()\n",
     "\n",
     "    qa_problem = questions[i]\n",
-    "    ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=qa_problem, n_results=30)"
+    "    chat_result = ragproxyagent.initiate_chat(\n",
+    "        assistant, message=ragproxyagent.message_generator, problem=qa_problem, n_results=30\n",
+    "    )"
    ]
   },
   {
@@ -2813,7 +2808,9 @@
     "    assistant.reset()\n",
     "\n",
     "    qa_problem = questions[i]\n",
-    "    ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=qa_problem, n_results=10)"
+    "    chat_result = ragproxyagent.initiate_chat(\n",
+    "        assistant, message=ragproxyagent.message_generator, problem=qa_problem, n_results=10\n",
+    "    )"
    ]
   }
  ],
@@ -2839,7 +2836,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.10.13"
+  "version": "3.11.9"
  },
 "skip_test": "Requires interactive usage"
 },


@@ -189,6 +189,7 @@
     "        file.write(\"\".join(file_contents))\n",
     "    return 0, \"Code modified\"\n",
     "\n",
+    "\n",
     "@user_proxy.register_for_execution()\n",
     "@engineer.register_for_llm(description=\"Create a new file with code.\")\n",
     "def create_file_with_code(\n",

File diff suppressed because one or more lines are too long


@@ -14,7 +14,7 @@ with open(os.path.join(here, "autogen/version.py")) as fp:
     __version__ = version["__version__"]
 
 install_requires = [
-    "openai>=1.3",
+    "openai>=1.23.3",
     "diskcache",
     "termcolor",
     "flaml",
@@ -35,7 +35,43 @@ jupyter_executor = [
     "ipykernel>=6.29.0",
 ]
 
-rag = ["sentence_transformers", "pypdf", "ipython", "beautifulsoup4", "markdownify"]
+retrieve_chat = ["chromadb", "sentence_transformers", "pypdf", "ipython", "beautifulsoup4", "markdownify"]
+
+extra_require = {
+    "test": [
+        "coverage>=5.3",
+        "ipykernel",
+        "nbconvert",
+        "nbformat",
+        "pre-commit",
+        "pytest-asyncio",
+        "pytest>=6.1.1,<8",
+        "pandas",
+    ],
+    "blendsearch": ["flaml[blendsearch]"],
+    "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
+    "retrievechat": retrieve_chat,
+    "retrievechat-pgvector": [
+        *retrieve_chat,
+        "pgvector>=0.2.5",
+        "psycopg>=3.1.18",
+    ],
+    "retrievechat-qdrant": [
+        *retrieve_chat,
+        "qdrant_client[fastembed]",
+    ],
+    "autobuild": ["chromadb", "sentence-transformers", "huggingface-hub"],
+    "teachable": ["chromadb"],
+    "lmm": ["replicate", "pillow"],
+    "graph": ["networkx", "matplotlib"],
+    "gemini": ["google-generativeai>=0.5,<1", "pillow", "pydantic"],
+    "websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
+    "redis": ["redis"],
+    "cosmosdb": ["azure-cosmos>=4.2.0"],
+    "websockets": ["websockets>=12.0,<13"],
+    "jupyter-executor": jupyter_executor,
+    "types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
+}
 
 setuptools.setup(
     name="pyautogen",
@@ -48,34 +84,7 @@ setuptools.setup(
     url="https://github.com/microsoft/autogen",
     packages=setuptools.find_packages(include=["autogen*"], exclude=["test"]),
     install_requires=install_requires,
-    extras_require={
-        "test": [
-            "coverage>=5.3",
-            "ipykernel",
-            "nbconvert",
-            "nbformat",
-            "pre-commit",
-            "pytest-asyncio",
-            "pytest>=6.1.1,<8",
-            "pandas",
-        ],
-        "blendsearch": ["flaml[blendsearch]"],
-        "mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
-        "retrievechat": ["chromadb"] + rag,
-        "retrievechat-pgvector": ["pgvector>=0.2.5", "psycopg>=3.1.18"] + rag,
-        "retrievechat-qdrant": ["qdrant_client[fastembed]"] + rag,
-        "autobuild": ["chromadb", "sentence-transformers", "huggingface-hub"],
-        "teachable": ["chromadb"],
-        "lmm": ["replicate", "pillow"],
-        "graph": ["networkx", "matplotlib"],
-        "gemini": ["google-generativeai>=0.5,<1", "pillow", "pydantic"],
-        "websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
-        "redis": ["redis"],
-        "cosmosdb": ["azure-cosmos>=4.2.0"],
-        "websockets": ["websockets>=12.0,<13"],
-        "jupyter-executor": jupyter_executor,
-        "types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
-    },
+    extras_require=extra_require,
     classifiers=[
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",


@@ -9,10 +9,10 @@ from sentence_transformers import SentenceTransformer
 from autogen import config_list_from_json
 from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
 
-sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
+sys.path.append(os.path.join(os.path.dirname(__file__), "../../.."))
 from conftest import skip_openai  # noqa: E402
 
-sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST  # noqa: E402
 
 try:
@@ -27,14 +27,14 @@
 except ImportError:
     skip = True
 else:
-    skip = False or skip_openai
+    skip = False
 
-test_dir = os.path.join(os.path.dirname(__file__), "../..", "test_files")
+test_dir = os.path.join(os.path.dirname(__file__), "../../..", "test_files")
 
 
 @pytest.mark.skipif(
-    skip,
+    skip or skip_openai,
     reason="dependency is not installed OR requested to skip",
 )
 def test_retrievechat():
@@ -97,34 +97,5 @@ def test_retrievechat():
     print(conversations)
 
 
-@pytest.mark.skipif(
-    skip,
-    reason="dependency is not installed",
-)
-def test_retrieve_config(caplog):
-    # test warning message when no docs_path is provided
-    ragproxyagent = RetrieveUserProxyAgent(
-        name="ragproxyagent",
-        human_input_mode="NEVER",
-        max_consecutive_auto_reply=2,
-        retrieve_config={
-            "chunk_token_size": 2000,
-            "get_or_create": True,
-        },
-    )
-
-    # Capture the printed content
-    captured_logs = caplog.records[0]
-    print(captured_logs)
-
-    # Assert on the printed content
-    assert (
-        f"docs_path is not provided in retrieve_config. Will raise ValueError if the collection `{ragproxyagent._collection_name}` doesn't exist."
-        in captured_logs.message
-    )
-    assert captured_logs.levelname == "WARNING"
-
-
 if __name__ == "__main__":
     test_retrievechat()
-    test_retrieve_config()


@@ -8,10 +8,10 @@ import pytest
 from autogen import config_list_from_json
 from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
 
-sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
+sys.path.append(os.path.join(os.path.dirname(__file__), "../../.."))
 from conftest import skip_openai  # noqa: E402
 
-sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST  # noqa: E402
 
 try:
@@ -35,7 +35,7 @@
 else:
     skip = False or skip_openai
 
-test_dir = os.path.join(os.path.dirname(__file__), "../..", "test_files")
+test_dir = os.path.join(os.path.dirname(__file__), "../../..", "test_files")
 
 
 @pytest.mark.skipif(


@@ -7,10 +7,10 @@ import pytest
 import autogen
 
-sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
-from conftest import skip_openai  # noqa: E402
+sys.path.append(os.path.join(os.path.dirname(__file__), "../../.."))
+from conftest import reason, skip_openai  # noqa: E402
 
-sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
+sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
 from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST  # noqa: E402
 
 try:
@@ -27,12 +27,14 @@
 except ImportError:
     skip = True
 else:
-    skip = False or skip_openai
+    skip = False
+
+reason = "do not run on MacOS or windows OR dependency is not installed OR " + reason
 
 
 @pytest.mark.skipif(
-    sys.platform in ["darwin", "win32"] or skip,
-    reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
+    sys.platform in ["darwin", "win32"] or skip or skip_openai,
+    reason=reason,
 )
 def test_retrievechat():
     conversations = {}
@@ -80,9 +82,9 @@ def test_retrievechat():
 @pytest.mark.skipif(
     sys.platform in ["darwin", "win32"] or skip,
-    reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
+    reason=reason,
 )
-def test_retrieve_config(caplog):
+def test_retrieve_config():
     # test warning message when no docs_path is provided
     ragproxyagent = RetrieveUserProxyAgent(
         name="ragproxyagent",
@@ -93,17 +95,7 @@ def test_retrieve_config(caplog):
             "get_or_create": True,
         },
     )
+    assert ragproxyagent._docs_path is None
 
-    # Capture the printed content
-    captured_logs = caplog.records[0]
-    print(captured_logs)
-
-    # Assert on the printed content
-    assert (
-        f"docs_path is not provided in retrieve_config. Will raise ValueError if the collection `{ragproxyagent._collection_name}` doesn't exist."
-        in captured_logs.message
-    )
-    assert captured_logs.levelname == "WARNING"
 
 
 if __name__ == "__main__":


@@ -2,6 +2,7 @@ import os
 import sys
 
 import pytest
 
+from conftest import reason
 
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
@@ -9,28 +10,33 @@
 try:
     import pgvector
     import sentence_transformers
 
-    from autogen.agentchat.contrib.vectordb.pgvector import PGVector
+    from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB
 except ImportError:
     skip = True
 else:
     skip = False
 
+reason = "do not run on MacOS or windows OR dependency is not installed OR " + reason
+
 
-@pytest.mark.skipif(skip, reason="dependency is not installed OR requested to skip")
+@pytest.mark.skipif(
+    sys.platform in ["darwin", "win32"] or skip,
+    reason=reason,
+)
 def test_pgvector():
     # test create collection
     db_config = {
         "connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
     }
 
-    db = PGVector(connection_string=db_config["connection_string"])
+    db = PGVectorDB(connection_string=db_config["connection_string"])
     collection_name = "test_collection"
-    collection = db.create_collection(collection_name, overwrite=True, get_or_create=True)
+    collection = db.create_collection(collection_name=collection_name, overwrite=True, get_or_create=True)
     assert collection.name == collection_name
 
     # test_delete_collection
     db.delete_collection(collection_name)
-    pytest.raises(ValueError, db.get_collection, collection_name)
+    assert collection.table_exists(table_name=collection_name) is False
 
     # test more create collection
     collection = db.create_collection(collection_name, overwrite=False, get_or_create=False)
@@ -48,21 +54,24 @@ def test_pgvector():
     # test_insert_docs
     docs = [{"content": "doc1", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
     db.insert_docs(docs, collection_name, upsert=False)
-    res = db.get_collection(collection_name).get(["1", "2"])
-    assert res["documents"] == ["doc1", "doc2"]
+    res = db.get_collection(collection_name).get(ids=["1", "2"])
+    final_results = [result.get("content") for result in res]
+    assert final_results == ["doc1", "doc2"]
 
     # test_update_docs
     docs = [{"content": "doc11", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
     db.update_docs(docs, collection_name)
     res = db.get_collection(collection_name).get(["1", "2"])
-    assert res["documents"] == ["doc11", "doc2"]
+    final_results = [result.get("content") for result in res]
+    assert final_results == ["doc11", "doc2"]
 
     # test_delete_docs
     ids = ["1"]
     collection_name = "test_collection"
     db.delete_docs(ids, collection_name)
     res = db.get_collection(collection_name).get(ids)
-    assert res["documents"] == []
+    final_results = [result.get("content") for result in res]
+    assert final_results == []
 
     # test_retrieve_docs
     queries = ["doc2", "doc3"]
@@ -70,12 +79,13 @@ def test_pgvector():
     res = db.retrieve_docs(queries, collection_name)
     assert [[r[0]["id"] for r in rr] for rr in res] == [["2", "3"], ["3", "2"]]
     res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1)
+    print(res)
     assert [[r[0]["id"] for r in rr] for rr in res] == [["2"], ["3"]]
 
     # test_get_docs_by_ids
     res = db.get_docs_by_ids(["1", "2"], collection_name)
     assert [r["id"] for r in res] == ["2"]  # "1" has been deleted
+    res = db.get_docs_by_ids(collection_name=collection_name)
+    assert set([r["id"] for r in res]) == set(["2", "3"])  # All Docs returned
 
 
 if __name__ == "__main__":