mirror of https://github.com/microsoft/autogen.git
2447 fix pgvector tests and notebook (#2455)
* Re-added missing notebook * Test installing postgres * Error handle the connection. * Fixed import. * Fixed import. * Fixed creation of collection without client. * PGVector portion working. OpenAI untested. * Fixed prints. * Added output. * Fixed pre-commits. * Run pgvector notebook * Improve efficiency of get_collection * Fix delete_collection * Fixed issues with pytests and validated functions. * Validated pytests. * Fixed pre-commits * Separated extra_requires to allow more logic. Retrieve_chat base dependencies included on pgvector and qdrant. * Fixed extra newline. * Added username and password fields. * URL Encode the connection string parameters to support symbols like % * Fixed pre-commits. * Added pgvector service * pgvector doesn't have health intervals. * Switched to colon based key values. * Run on Ubuntu only. Linux is only option with container service support. * Using default credentials instead. * Fix postgres setup * Fix postgres setup * Don't skip tests on win and mac * Fix command error * Try apt install postgresql * Assert table does not exist when deleted. * Raise value error on a empty list or None value provided for IDs * pre-commit * Add install pgvector * Add install pgvector * Reorg test files, create a separate job for test pgvector * Fix format * Fix env format * Simplify job name, enable test_retrieve_config * Fix test_retrieve_config * Corrected behavior for get_docs_by_ids with no ids returning all docs. * Corrected behavior for get_docs_by_ids with no ids returning all docs. * Fixed pre-commits. * Added return values for all functions. * Validated distance search is implemented correctly. * Validated all pytests * Removed print. * Added default clause. * Make ids optional * Fix test, make it more robust * Bump version of openai for the vector_store support * Added support for choosing the sentence transformer model. * Added error handling for model name entered. * Updated model info. * Added model_name db_config param. * pre-commit fixes and last link fix. * Use secrets password. * fix: link fixed * updated tests * Updated config_list. * pre-commit fix. * Added chat_result to all output. Unable to re-run notebooks. * Pre-commit fix detected this requirement. * Fix python 3.8 and 3.9 not supported for macos * Fix python 3.8 and 3.9 not supported for macos * Fix format * Reran notebook with MetaLlama3Instruct7BQ4_k_M * added gpt model. * Reran notebook --------- Co-authored-by: Li Jiang <bnujli@gmail.com> Co-authored-by: Hk669 <hrushi669@gmail.com>
This commit is contained in:
parent
600bd3f2fe
commit
1b8d65df0a
|
@ -24,6 +24,21 @@ jobs:
|
||||||
python-version: ["3.10"]
|
python-version: ["3.10"]
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
environment: openai1
|
environment: openai1
|
||||||
|
services:
|
||||||
|
pgvector:
|
||||||
|
image: ankane/pgvector
|
||||||
|
env:
|
||||||
|
POSTGRES_DB: postgres
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
|
||||||
|
POSTGRES_HOST_AUTH_METHOD: trust
|
||||||
|
options: >-
|
||||||
|
--health-cmd pg_isready
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
steps:
|
steps:
|
||||||
# checkout to pr branch
|
# checkout to pr branch
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
@ -41,15 +56,10 @@ jobs:
|
||||||
pip install -e .
|
pip install -e .
|
||||||
python -c "import autogen"
|
python -c "import autogen"
|
||||||
pip install coverage pytest-asyncio
|
pip install coverage pytest-asyncio
|
||||||
- name: Install PostgreSQL
|
|
||||||
run: |
|
|
||||||
sudo apt install postgresql -y
|
|
||||||
- name: Start PostgreSQL service
|
|
||||||
run: sudo service postgresql start
|
|
||||||
- name: Install packages for test when needed
|
- name: Install packages for test when needed
|
||||||
run: |
|
run: |
|
||||||
pip install docker
|
pip install docker
|
||||||
pip install -e .[retrievechat-qdrant,retrievechat-pgvector]
|
pip install -e .[retrievechat,retrievechat-qdrant,retrievechat-pgvector]
|
||||||
- name: Coverage
|
- name: Coverage
|
||||||
env:
|
env:
|
||||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||||
|
@ -57,13 +67,14 @@ jobs:
|
||||||
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
|
AZURE_OPENAI_API_BASE: ${{ secrets.AZURE_OPENAI_API_BASE }}
|
||||||
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
|
OAI_CONFIG_LIST: ${{ secrets.OAI_CONFIG_LIST }}
|
||||||
run: |
|
run: |
|
||||||
coverage run -a -m pytest test/agentchat/contrib/test_retrievechat.py::test_retrievechat test/agentchat/contrib/test_qdrant_retrievechat.py::test_retrievechat test/agentchat/contrib/test_pgvector_retrievechat.py::test_retrievechat
|
coverage run -a -m pytest -k test_retrievechat test/agentchat/contrib/retrievechat
|
||||||
coverage xml
|
coverage xml
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
uses: codecov/codecov-action@v3
|
uses: codecov/codecov-action@v3
|
||||||
with:
|
with:
|
||||||
file: ./coverage.xml
|
file: ./coverage.xml
|
||||||
flags: unittests
|
flags: unittests
|
||||||
|
|
||||||
CompressionTest:
|
CompressionTest:
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
|
|
|
@ -27,7 +27,7 @@ jobs:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest, macos-latest, windows-2019]
|
os: [macos-latest, windows-2019]
|
||||||
python-version: ["3.9", "3.10", "3.11"]
|
python-version: ["3.9", "3.10", "3.11"]
|
||||||
exclude:
|
exclude:
|
||||||
- os: macos-latest
|
- os: macos-latest
|
||||||
|
@ -45,20 +45,10 @@ jobs:
|
||||||
- name: Install qdrant_client when python-version is 3.10
|
- name: Install qdrant_client when python-version is 3.10
|
||||||
if: matrix.python-version == '3.10'
|
if: matrix.python-version == '3.10'
|
||||||
run: |
|
run: |
|
||||||
pip install .[retrievechat-qdrant]
|
pip install -e .[retrievechat-qdrant]
|
||||||
- name: Install unstructured when python-version is 3.9 and on linux
|
- name: Install packages and dependencies for RetrieveChat
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update
|
pip install -e .[retrievechat]
|
||||||
sudo apt-get install -y tesseract-ocr poppler-utils
|
|
||||||
pip install unstructured[all-docs]==0.13.0
|
|
||||||
- name: Install and Start PostgreSQL
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
run: |
|
|
||||||
sudo apt install postgresql -y
|
|
||||||
sudo service postgresql start
|
|
||||||
- name: Install packages and dependencies for PGVector
|
|
||||||
run: |
|
|
||||||
pip install -e .[retrievechat-pgvector]
|
|
||||||
- name: Set AUTOGEN_USE_DOCKER based on OS
|
- name: Set AUTOGEN_USE_DOCKER based on OS
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
|
@ -68,7 +58,69 @@ jobs:
|
||||||
- name: Coverage
|
- name: Coverage
|
||||||
run: |
|
run: |
|
||||||
pip install coverage>=5.3
|
pip install coverage>=5.3
|
||||||
coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/test_retrievechat.py test/agentchat/contrib/test_qdrant_retrievechat.py test/agentchat/contrib/vectordb --skip-openai
|
coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/retrievechat test/agentchat/contrib/vectordb --skip-openai
|
||||||
|
coverage xml
|
||||||
|
- name: Upload coverage to Codecov
|
||||||
|
uses: codecov/codecov-action@v3
|
||||||
|
with:
|
||||||
|
file: ./coverage.xml
|
||||||
|
flags: unittests
|
||||||
|
|
||||||
|
RetrieveChatTest-Ubuntu:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python-version: ["3.9", "3.10", "3.11"]
|
||||||
|
services:
|
||||||
|
pgvector:
|
||||||
|
image: ankane/pgvector
|
||||||
|
env:
|
||||||
|
POSTGRES_DB: postgres
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: ${{ secrets.POSTGRES_PASSWORD }}
|
||||||
|
POSTGRES_HOST_AUTH_METHOD: trust
|
||||||
|
options: >-
|
||||||
|
--health-cmd pg_isready
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
--health-retries 5
|
||||||
|
ports:
|
||||||
|
- 5432:5432
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
- name: Install packages and dependencies for all tests
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip wheel
|
||||||
|
pip install pytest
|
||||||
|
- name: Install qdrant_client when python-version is 3.10
|
||||||
|
if: matrix.python-version == '3.10'
|
||||||
|
run: |
|
||||||
|
pip install -e .[retrievechat-qdrant]
|
||||||
|
- name: Install pgvector when on linux
|
||||||
|
run: |
|
||||||
|
pip install -e .[retrievechat-pgvector]
|
||||||
|
- name: Install unstructured when python-version is 3.9 and on linux
|
||||||
|
if: matrix.python-version == '3.9'
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y tesseract-ocr poppler-utils
|
||||||
|
pip install unstructured[all-docs]==0.13.0
|
||||||
|
- name: Install packages and dependencies for RetrieveChat
|
||||||
|
run: |
|
||||||
|
pip install -e .[retrievechat]
|
||||||
|
- name: Set AUTOGEN_USE_DOCKER based on OS
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
echo "AUTOGEN_USE_DOCKER=False" >> $GITHUB_ENV
|
||||||
|
- name: Coverage
|
||||||
|
run: |
|
||||||
|
pip install coverage>=5.3
|
||||||
|
coverage run -a -m pytest test/test_retrieve_utils.py test/agentchat/contrib/retrievechat test/agentchat/contrib/vectordb --skip-openai
|
||||||
coverage xml
|
coverage xml
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
uses: codecov/codecov-action@v3
|
uses: codecov/codecov-action@v3
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import urllib.parse
|
||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -33,7 +34,8 @@ class Collection:
|
||||||
embedding_function (Callable): The embedding function used to generate the vector representation.
|
embedding_function (Callable): The embedding function used to generate the vector representation.
|
||||||
metadata (Optional[dict]): The metadata of the collection.
|
metadata (Optional[dict]): The metadata of the collection.
|
||||||
get_or_create (Optional): The flag indicating whether to get or create the collection.
|
get_or_create (Optional): The flag indicating whether to get or create the collection.
|
||||||
|
model_name: (Optional str) | Sentence embedding model to use. Models can be chosen from:
|
||||||
|
https://huggingface.co/models?library=sentence-transformers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -43,6 +45,7 @@ class Collection:
|
||||||
embedding_function: Callable = None,
|
embedding_function: Callable = None,
|
||||||
metadata=None,
|
metadata=None,
|
||||||
get_or_create=None,
|
get_or_create=None,
|
||||||
|
model_name="all-MiniLM-L6-v2",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the Collection object.
|
Initialize the Collection object.
|
||||||
|
@ -53,46 +56,76 @@ class Collection:
|
||||||
embedding_function: The embedding function used to generate the vector representation.
|
embedding_function: The embedding function used to generate the vector representation.
|
||||||
metadata: The metadata of the collection.
|
metadata: The metadata of the collection.
|
||||||
get_or_create: The flag indicating whether to get or create the collection.
|
get_or_create: The flag indicating whether to get or create the collection.
|
||||||
|
model_name: | Sentence embedding model to use. Models can be chosen from:
|
||||||
|
https://huggingface.co/models?library=sentence-transformers
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
self.client = client
|
self.client = client
|
||||||
self.embedding_function = embedding_function
|
self.embedding_function = embedding_function
|
||||||
|
self.model_name = model_name
|
||||||
self.name = self.set_collection_name(collection_name)
|
self.name = self.set_collection_name(collection_name)
|
||||||
self.require_embeddings_or_documents = False
|
self.require_embeddings_or_documents = False
|
||||||
self.ids = []
|
self.ids = []
|
||||||
self.embedding_function = (
|
try:
|
||||||
SentenceTransformer("all-MiniLM-L6-v2") if embedding_function is None else embedding_function
|
self.embedding_function = (
|
||||||
)
|
SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Validate the model name entered: {self.model_name} "
|
||||||
|
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
|
self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
|
||||||
self.documents = ""
|
self.documents = ""
|
||||||
self.get_or_create = get_or_create
|
self.get_or_create = get_or_create
|
||||||
|
|
||||||
def set_collection_name(self, collection_name):
|
def set_collection_name(self, collection_name) -> str:
|
||||||
name = re.sub("-", "_", collection_name)
|
name = re.sub("-", "_", collection_name)
|
||||||
self.name = name
|
self.name = name
|
||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def add(self, ids: List[ItemID], embeddings: List, metadatas: List, documents: List):
|
def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metadatas: List = None) -> None:
|
||||||
"""
|
"""
|
||||||
Add documents to the collection.
|
Add documents to the collection.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ids (List[ItemID]): A list of document IDs.
|
ids (List[ItemID]): A list of document IDs.
|
||||||
embeddings (List): A list of document embeddings.
|
embeddings (List): A list of document embeddings. Optional
|
||||||
metadatas (List): A list of document metadatas.
|
metadatas (List): A list of document metadatas. Optional
|
||||||
documents (List): A list of documents.
|
documents (List): A list of documents.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
cursor = self.client.cursor()
|
cursor = self.client.cursor()
|
||||||
|
|
||||||
sql_values = []
|
sql_values = []
|
||||||
for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
|
if embeddings is not None and metadatas is not None:
|
||||||
sql_values.append((doc_id, embedding, metadata, document))
|
for doc_id, embedding, metadata, document in zip(ids, embeddings, metadatas, documents):
|
||||||
sql_string = f"INSERT INTO {self.name} (id, embedding, metadata, document) " f"VALUES (%s, %s, %s, %s);"
|
metadata = re.sub("'", '"', str(metadata))
|
||||||
|
sql_values.append((doc_id, embedding, metadata, document))
|
||||||
|
sql_string = (
|
||||||
|
f"INSERT INTO {self.name} (id, embedding, metadatas, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
|
||||||
|
)
|
||||||
|
elif embeddings is not None:
|
||||||
|
for doc_id, embedding, document in zip(ids, embeddings, documents):
|
||||||
|
sql_values.append((doc_id, embedding, document))
|
||||||
|
sql_string = f"INSERT INTO {self.name} (id, embedding, documents) " f"VALUES (%s, %s, %s);\n"
|
||||||
|
elif metadatas is not None:
|
||||||
|
for doc_id, metadata, document in zip(ids, metadatas, documents):
|
||||||
|
metadata = re.sub("'", '"', str(metadata))
|
||||||
|
embedding = self.embedding_function.encode(document)
|
||||||
|
sql_values.append((doc_id, metadata, embedding, document))
|
||||||
|
sql_string = (
|
||||||
|
f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for doc_id, document in zip(ids, documents):
|
||||||
|
embedding = self.embedding_function.encode(document)
|
||||||
|
sql_values.append((doc_id, document, embedding))
|
||||||
|
sql_string = f"INSERT INTO {self.name} (id, documents, embedding)\n" f"VALUES (%s, %s, %s);\n"
|
||||||
|
logger.debug(f"Add SQL String:\n{sql_string}\n{sql_values}")
|
||||||
cursor.executemany(sql_string, sql_values)
|
cursor.executemany(sql_string, sql_values)
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
|
@ -155,7 +188,7 @@ class Collection:
|
||||||
cursor.executemany(sql_string, sql_values)
|
cursor.executemany(sql_string, sql_values)
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
def count(self):
|
def count(self) -> int:
|
||||||
"""
|
"""
|
||||||
Get the total number of documents in the collection.
|
Get the total number of documents in the collection.
|
||||||
|
|
||||||
|
@ -173,7 +206,32 @@ class Collection:
|
||||||
total = None
|
total = None
|
||||||
return total
|
return total
|
||||||
|
|
||||||
def get(self, ids=None, include=None, where=None, limit=None, offset=None):
|
def table_exists(self, table_name: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a table exists in the PostgreSQL database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name (str): The name of the table to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the table exists, False otherwise.
|
||||||
|
"""
|
||||||
|
|
||||||
|
cursor = self.client.cursor()
|
||||||
|
cursor.execute(
|
||||||
|
"""
|
||||||
|
SELECT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM information_schema.tables
|
||||||
|
WHERE table_name = %s
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
(table_name,),
|
||||||
|
)
|
||||||
|
exists = cursor.fetchone()[0]
|
||||||
|
return exists
|
||||||
|
|
||||||
|
def get(self, ids=None, include=None, where=None, limit=None, offset=None) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
Retrieve documents from the collection.
|
Retrieve documents from the collection.
|
||||||
|
|
||||||
|
@ -188,39 +246,65 @@ class Collection:
|
||||||
List: The retrieved documents.
|
List: The retrieved documents.
|
||||||
"""
|
"""
|
||||||
cursor = self.client.cursor()
|
cursor = self.client.cursor()
|
||||||
|
|
||||||
|
# Initialize variables for query components
|
||||||
|
select_clause = "SELECT id, metadatas, documents, embedding"
|
||||||
|
from_clause = f"FROM {self.name}"
|
||||||
|
where_clause = ""
|
||||||
|
limit_clause = ""
|
||||||
|
offset_clause = ""
|
||||||
|
|
||||||
|
# Handle include clause
|
||||||
if include:
|
if include:
|
||||||
query = f'SELECT (id, {", ".join(map(str, include))}, embedding) FROM {self.name}'
|
select_clause = f"SELECT id, {', '.join(include)}, embedding"
|
||||||
else:
|
|
||||||
query = f"SELECT * FROM {self.name}"
|
# Handle where clause
|
||||||
if ids:
|
if ids:
|
||||||
query = f"{query} WHERE id IN {ids}"
|
where_clause = f"WHERE id IN ({', '.join(['%s' for _ in ids])})"
|
||||||
elif where:
|
elif where:
|
||||||
query = f"{query} WHERE {where}"
|
where_clause = f"WHERE {where}"
|
||||||
if offset:
|
|
||||||
query = f"{query} OFFSET {offset}"
|
# Handle limit and offset clauses
|
||||||
if limit:
|
if limit:
|
||||||
query = f"{query} LIMIT {limit}"
|
limit_clause = "LIMIT %s"
|
||||||
retreived_documents = []
|
if offset:
|
||||||
|
offset_clause = "OFFSET %s"
|
||||||
|
|
||||||
|
# Construct the full query
|
||||||
|
query = f"{select_clause} {from_clause} {where_clause} {limit_clause} {offset_clause}"
|
||||||
|
|
||||||
|
retrieved_documents = []
|
||||||
try:
|
try:
|
||||||
cursor.execute(query)
|
# Execute the query with the appropriate values
|
||||||
|
if ids is not None:
|
||||||
|
cursor.execute(query, ids)
|
||||||
|
else:
|
||||||
|
query_params = []
|
||||||
|
if limit:
|
||||||
|
query_params.append(limit)
|
||||||
|
if offset:
|
||||||
|
query_params.append(offset)
|
||||||
|
cursor.execute(query, query_params)
|
||||||
|
|
||||||
retrieval = cursor.fetchall()
|
retrieval = cursor.fetchall()
|
||||||
for retrieved_document in retrieval:
|
for retrieved_document in retrieval:
|
||||||
retreived_documents.append(
|
retrieved_documents.append(
|
||||||
Document(
|
Document(
|
||||||
id=retrieved_document[0][0],
|
id=retrieved_document[0].strip(),
|
||||||
metadata=retrieved_document[0][1],
|
metadata=retrieved_document[1],
|
||||||
content=retrieved_document[0][2],
|
content=retrieved_document[2],
|
||||||
embedding=retrieved_document[0][3],
|
embedding=retrieved_document[3],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn):
|
except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn) as e:
|
||||||
logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead.")
|
logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead. Error: {e}")
|
||||||
self.create_collection(collection_name=self.name)
|
self.create_collection(collection_name=self.name)
|
||||||
logger.info(f"Created table {self.name}")
|
logger.info(f"Created table {self.name}")
|
||||||
cursor.close()
|
|
||||||
return retreived_documents
|
|
||||||
|
|
||||||
def update(self, ids: List, embeddings: List, metadatas: List, documents: List):
|
cursor.close()
|
||||||
|
return retrieved_documents
|
||||||
|
|
||||||
|
def update(self, ids: List, embeddings: List, metadatas: List, documents: List) -> None:
|
||||||
"""
|
"""
|
||||||
Update documents in the collection.
|
Update documents in the collection.
|
||||||
|
|
||||||
|
@ -300,6 +384,7 @@ class Collection:
|
||||||
n_results: int = 10,
|
n_results: int = 10,
|
||||||
distance_type: str = "euclidean",
|
distance_type: str = "euclidean",
|
||||||
distance_threshold: float = -1,
|
distance_threshold: float = -1,
|
||||||
|
include_embedding: bool = False,
|
||||||
) -> QueryResults:
|
) -> QueryResults:
|
||||||
"""
|
"""
|
||||||
Query documents in the collection.
|
Query documents in the collection.
|
||||||
|
@ -310,21 +395,25 @@ class Collection:
|
||||||
n_results (int): The maximum number of results to return.
|
n_results (int): The maximum number of results to return.
|
||||||
distance_type (Optional[str]): Distance search type - euclidean or cosine
|
distance_type (Optional[str]): Distance search type - euclidean or cosine
|
||||||
distance_threshold (Optional[float]): Distance threshold to limit searches
|
distance_threshold (Optional[float]): Distance threshold to limit searches
|
||||||
|
include_embedding (Optional[bool]): Include embedding values in QueryResults
|
||||||
Returns:
|
Returns:
|
||||||
QueryResults: The query results.
|
QueryResults: The query results.
|
||||||
"""
|
"""
|
||||||
if collection_name:
|
if collection_name:
|
||||||
self.name = collection_name
|
self.name = collection_name
|
||||||
|
|
||||||
|
clause = "ORDER BY"
|
||||||
if distance_threshold == -1:
|
if distance_threshold == -1:
|
||||||
distance_threshold = ""
|
distance_threshold = ""
|
||||||
|
clause = "ORDER BY"
|
||||||
elif distance_threshold > 0:
|
elif distance_threshold > 0:
|
||||||
distance_threshold = f"< {distance_threshold}"
|
distance_threshold = f"< {distance_threshold}"
|
||||||
|
clause = "WHERE"
|
||||||
|
|
||||||
cursor = self.client.cursor()
|
cursor = self.client.cursor()
|
||||||
results = []
|
results = []
|
||||||
for query in query_texts:
|
for query_text in query_texts:
|
||||||
vector = self.embedding_function.encode(query, convert_to_tensor=False).tolist()
|
vector = self.embedding_function.encode(query_text, convert_to_tensor=False).tolist()
|
||||||
if distance_type.lower() == "cosine":
|
if distance_type.lower() == "cosine":
|
||||||
index_function = "<=>"
|
index_function = "<=>"
|
||||||
elif distance_type.lower() == "euclidean":
|
elif distance_type.lower() == "euclidean":
|
||||||
|
@ -333,15 +422,16 @@ class Collection:
|
||||||
index_function = "<#>"
|
index_function = "<#>"
|
||||||
else:
|
else:
|
||||||
index_function = "<->"
|
index_function = "<->"
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
f"SELECT id, documents, embedding, metadatas FROM {self.name}\n"
|
f"SELECT id, documents, embedding, metadatas "
|
||||||
f"ORDER BY embedding {index_function} '{str(vector)}'::vector {distance_threshold}\n"
|
f"FROM {self.name} "
|
||||||
|
f"{clause} embedding {index_function} '{str(vector)}' {distance_threshold} "
|
||||||
f"LIMIT {n_results}"
|
f"LIMIT {n_results}"
|
||||||
)
|
)
|
||||||
cursor.execute(query)
|
cursor.execute(query)
|
||||||
|
result = []
|
||||||
for row in cursor.fetchall():
|
for row in cursor.fetchall():
|
||||||
fetched_document = Document(id=row[0], content=row[1], embedding=row[2], metadata=row[3])
|
fetched_document = Document(id=row[0].strip(), content=row[1], embedding=row[2], metadata=row[3])
|
||||||
fetched_document_array = self.convert_string_to_array(array_string=fetched_document.get("embedding"))
|
fetched_document_array = self.convert_string_to_array(array_string=fetched_document.get("embedding"))
|
||||||
if distance_type.lower() == "cosine":
|
if distance_type.lower() == "cosine":
|
||||||
distance = self.cosine_distance(fetched_document_array, vector)
|
distance = self.cosine_distance(fetched_document_array, vector)
|
||||||
|
@ -351,9 +441,11 @@ class Collection:
|
||||||
distance = self.inner_product_distance(fetched_document_array, vector)
|
distance = self.inner_product_distance(fetched_document_array, vector)
|
||||||
else:
|
else:
|
||||||
distance = self.euclidean_distance(fetched_document_array, vector)
|
distance = self.euclidean_distance(fetched_document_array, vector)
|
||||||
results.append((fetched_document, distance))
|
if not include_embedding:
|
||||||
|
fetched_document = Document(id=row[0].strip(), content=row[1], metadata=row[3])
|
||||||
|
result.append((fetched_document, distance))
|
||||||
|
results.append(result)
|
||||||
cursor.close()
|
cursor.close()
|
||||||
results = [results]
|
|
||||||
logger.debug(f"Query Results: {results}")
|
logger.debug(f"Query Results: {results}")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
@ -375,7 +467,7 @@ class Collection:
|
||||||
array = [float(num) for num in array_string.split()]
|
array = [float(num) for num in array_string.split()]
|
||||||
return array
|
return array
|
||||||
|
|
||||||
def modify(self, metadata, collection_name: str = None):
|
def modify(self, metadata, collection_name: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Modify metadata for the collection.
|
Modify metadata for the collection.
|
||||||
|
|
||||||
|
@ -394,7 +486,7 @@ class Collection:
|
||||||
)
|
)
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
def delete(self, ids: List[ItemID], collection_name: str = None):
|
def delete(self, ids: List[ItemID], collection_name: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Delete documents from the collection.
|
Delete documents from the collection.
|
||||||
|
|
||||||
|
@ -408,10 +500,11 @@ class Collection:
|
||||||
if collection_name:
|
if collection_name:
|
||||||
self.name = collection_name
|
self.name = collection_name
|
||||||
cursor = self.client.cursor()
|
cursor = self.client.cursor()
|
||||||
cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({ids});")
|
id_placeholders = ", ".join(["%s" for _ in ids])
|
||||||
|
cursor.execute(f"DELETE FROM {self.name} WHERE id IN ({id_placeholders});", ids)
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
def delete_collection(self, collection_name: str = None):
|
def delete_collection(self, collection_name: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Delete the entire collection.
|
Delete the entire collection.
|
||||||
|
|
||||||
|
@ -427,7 +520,7 @@ class Collection:
|
||||||
cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
|
cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
|
||||||
cursor.close()
|
cursor.close()
|
||||||
|
|
||||||
def create_collection(self, collection_name: str = None):
|
def create_collection(self, collection_name: str = None) -> None:
|
||||||
"""
|
"""
|
||||||
Create a new collection.
|
Create a new collection.
|
||||||
|
|
||||||
|
@ -468,9 +561,12 @@ class PGVectorDB(VectorDB):
|
||||||
host: str = None,
|
host: str = None,
|
||||||
port: int = None,
|
port: int = None,
|
||||||
dbname: str = None,
|
dbname: str = None,
|
||||||
|
username: str = None,
|
||||||
|
password: str = None,
|
||||||
connect_timeout: int = 10,
|
connect_timeout: int = 10,
|
||||||
embedding_function: Callable = None,
|
embedding_function: Callable = None,
|
||||||
metadata: dict = None,
|
metadata: dict = None,
|
||||||
|
model_name: str = "all-MiniLM-L6-v2",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the vector database.
|
Initialize the vector database.
|
||||||
|
@ -482,6 +578,8 @@ class PGVectorDB(VectorDB):
|
||||||
host: str | The host to connect to. Default is None.
|
host: str | The host to connect to. Default is None.
|
||||||
port: int | The port to connect to. Default is None.
|
port: int | The port to connect to. Default is None.
|
||||||
dbname: str | The database name to connect to. Default is None.
|
dbname: str | The database name to connect to. Default is None.
|
||||||
|
username: str | The database username to use. Default is None.
|
||||||
|
password: str | The database user password to use. Default is None.
|
||||||
connect_timeout: int | The timeout to set for the connection. Default is 10.
|
connect_timeout: int | The timeout to set for the connection. Default is 10.
|
||||||
embedding_function: Callable | The embedding function used to generate the vector representation
|
embedding_function: Callable | The embedding function used to generate the vector representation
|
||||||
of the documents. Default is None.
|
of the documents. Default is None.
|
||||||
|
@ -489,20 +587,48 @@ class PGVectorDB(VectorDB):
|
||||||
setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 16}. Creates Index on table
|
setting: {"hnsw:space": "ip", "hnsw:construction_ef": 30, "hnsw:M": 16}. Creates Index on table
|
||||||
using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef".
|
using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef".
|
||||||
For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
|
For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
|
||||||
kwargs: dict | Additional keyword arguments.
|
model_name: str | Sentence embedding model to use. Models can be chosen from:
|
||||||
|
https://huggingface.co/models?library=sentence-transformers
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
if connection_string:
|
try:
|
||||||
self.client = psycopg.connect(conninfo=connection_string, autocommit=True)
|
if connection_string:
|
||||||
elif host and port and dbname:
|
parsed_connection = urllib.parse.urlparse(connection_string)
|
||||||
self.client = psycopg.connect(
|
encoded_username = urllib.parse.quote(parsed_connection.username, safe="")
|
||||||
host=host, port=port, dbname=dbname, connect_timeout=connect_timeout, autocommit=True
|
encoded_password = urllib.parse.quote(parsed_connection.password, safe="")
|
||||||
|
encoded_host = urllib.parse.quote(parsed_connection.hostname, safe="")
|
||||||
|
encoded_database = urllib.parse.quote(parsed_connection.path[1:], safe="")
|
||||||
|
connection_string_encoded = (
|
||||||
|
f"{parsed_connection.scheme}://{encoded_username}:{encoded_password}"
|
||||||
|
f"@{encoded_host}:{parsed_connection.port}/{encoded_database}"
|
||||||
|
)
|
||||||
|
self.client = psycopg.connect(conninfo=connection_string_encoded, autocommit=True)
|
||||||
|
elif host and port and dbname:
|
||||||
|
self.client = psycopg.connect(
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
dbname=dbname,
|
||||||
|
username=username,
|
||||||
|
password=password,
|
||||||
|
connect_timeout=connect_timeout,
|
||||||
|
autocommit=True,
|
||||||
|
)
|
||||||
|
except psycopg.Error as e:
|
||||||
|
logger.error("Error connecting to the database: ", e)
|
||||||
|
raise e
|
||||||
|
self.model_name = model_name
|
||||||
|
try:
|
||||||
|
self.embedding_function = (
|
||||||
|
SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
|
||||||
)
|
)
|
||||||
self.embedding_function = (
|
except Exception as e:
|
||||||
SentenceTransformer("all-MiniLM-L6-v2") if embedding_function is None else embedding_function
|
logger.error(
|
||||||
)
|
f"Validate the model name entered: {self.model_name} "
|
||||||
|
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
self.client.execute("CREATE EXTENSION IF NOT EXISTS vector")
|
||||||
register_vector(self.client)
|
register_vector(self.client)
|
||||||
|
@ -535,10 +661,12 @@ class PGVectorDB(VectorDB):
|
||||||
collection = None
|
collection = None
|
||||||
if collection is None:
|
if collection is None:
|
||||||
collection = Collection(
|
collection = Collection(
|
||||||
|
client=self.client,
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
embedding_function=self.embedding_function,
|
embedding_function=self.embedding_function,
|
||||||
get_or_create=get_or_create,
|
get_or_create=get_or_create,
|
||||||
metadata=self.metadata,
|
metadata=self.metadata,
|
||||||
|
model_name=self.model_name,
|
||||||
)
|
)
|
||||||
collection.set_collection_name(collection_name=collection_name)
|
collection.set_collection_name(collection_name=collection_name)
|
||||||
collection.create_collection(collection_name=collection_name)
|
collection.create_collection(collection_name=collection_name)
|
||||||
|
@ -546,16 +674,30 @@ class PGVectorDB(VectorDB):
|
||||||
elif overwrite:
|
elif overwrite:
|
||||||
self.delete_collection(collection_name)
|
self.delete_collection(collection_name)
|
||||||
collection = Collection(
|
collection = Collection(
|
||||||
|
client=self.client,
|
||||||
collection_name=collection_name,
|
collection_name=collection_name,
|
||||||
embedding_function=self.embedding_function,
|
embedding_function=self.embedding_function,
|
||||||
get_or_create=get_or_create,
|
get_or_create=get_or_create,
|
||||||
metadata=self.metadata,
|
metadata=self.metadata,
|
||||||
|
model_name=self.model_name,
|
||||||
)
|
)
|
||||||
collection.set_collection_name(collection_name=collection_name)
|
collection.set_collection_name(collection_name=collection_name)
|
||||||
collection.create_collection(collection_name=collection_name)
|
collection.create_collection(collection_name=collection_name)
|
||||||
return collection
|
return collection
|
||||||
elif get_or_create:
|
elif get_or_create:
|
||||||
return collection
|
return collection
|
||||||
|
elif not collection.table_exists(table_name=collection_name):
|
||||||
|
collection = Collection(
|
||||||
|
client=self.client,
|
||||||
|
collection_name=collection_name,
|
||||||
|
embedding_function=self.embedding_function,
|
||||||
|
get_or_create=get_or_create,
|
||||||
|
metadata=self.metadata,
|
||||||
|
model_name=self.model_name,
|
||||||
|
)
|
||||||
|
collection.set_collection_name(collection_name=collection_name)
|
||||||
|
collection.create_collection(collection_name=collection_name)
|
||||||
|
return collection
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Collection {collection_name} already exists.")
|
raise ValueError(f"Collection {collection_name} already exists.")
|
||||||
|
|
||||||
|
@ -578,9 +720,13 @@ class PGVectorDB(VectorDB):
|
||||||
f"No collection is specified. Using current active collection {self.active_collection.name}."
|
f"No collection is specified. Using current active collection {self.active_collection.name}."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.active_collection = Collection(
|
if not (self.active_collection and self.active_collection.name == collection_name):
|
||||||
client=self.client, collection_name=collection_name, embedding_function=self.embedding_function
|
self.active_collection = Collection(
|
||||||
)
|
client=self.client,
|
||||||
|
collection_name=collection_name,
|
||||||
|
embedding_function=self.embedding_function,
|
||||||
|
model_name=self.model_name,
|
||||||
|
)
|
||||||
return self.active_collection
|
return self.active_collection
|
||||||
|
|
||||||
def delete_collection(self, collection_name: str) -> None:
|
def delete_collection(self, collection_name: str) -> None:
|
||||||
|
@ -593,16 +739,20 @@ class PGVectorDB(VectorDB):
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
self.active_collection.delete_collection(collection_name)
|
if self.active_collection:
|
||||||
|
self.active_collection.delete_collection(collection_name)
|
||||||
|
else:
|
||||||
|
collection = self.get_collection(collection_name)
|
||||||
|
collection.delete_collection(collection_name)
|
||||||
if self.active_collection and self.active_collection.name == collection_name:
|
if self.active_collection and self.active_collection.name == collection_name:
|
||||||
self.active_collection = None
|
self.active_collection = None
|
||||||
|
|
||||||
def _batch_insert(
|
def _batch_insert(
|
||||||
self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
|
self, collection: Collection, embeddings=None, ids=None, metadatas=None, documents=None, upsert=False
|
||||||
):
|
) -> None:
|
||||||
batch_size = int(PGVECTOR_MAX_BATCH_SIZE)
|
batch_size = int(PGVECTOR_MAX_BATCH_SIZE)
|
||||||
default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
|
default_metadata = {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
|
||||||
default_metadatas = [default_metadata]
|
default_metadatas = [default_metadata] * min(batch_size, len(documents))
|
||||||
for i in range(0, len(documents), min(batch_size, len(documents))):
|
for i in range(0, len(documents), min(batch_size, len(documents))):
|
||||||
end_idx = i + min(batch_size, len(documents) - i)
|
end_idx = i + min(batch_size, len(documents) - i)
|
||||||
collection_kwargs = {
|
collection_kwargs = {
|
||||||
|
@ -715,12 +865,14 @@ class PGVectorDB(VectorDB):
|
||||||
logger.debug(f"Retrieve Docs Results:\n{results}")
|
logger.debug(f"Retrieve Docs Results:\n{results}")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_docs_by_ids(self, ids: List[ItemID], collection_name: str = None, include=None, **kwargs) -> List[Document]:
|
def get_docs_by_ids(
|
||||||
|
self, ids: List[ItemID] = None, collection_name: str = None, include=None, **kwargs
|
||||||
|
) -> List[Document]:
|
||||||
"""
|
"""
|
||||||
Retrieve documents from the collection of the vector database based on the ids.
|
Retrieve documents from the collection of the vector database based on the ids.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ids: List[ItemID] | A list of document ids.
|
ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
|
||||||
collection_name: str | The name of the collection. Default is None.
|
collection_name: str | The name of the collection. Default is None.
|
||||||
include: List[str] | The fields to include. Default is None.
|
include: List[str] | The fields to include. Default is None.
|
||||||
If None, will include ["metadatas", "documents"], ids will always be included.
|
If None, will include ["metadatas", "documents"], ids will always be included.
|
||||||
|
|
|
@ -48,14 +48,14 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 4,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"models to use: ['gpt-35-turbo', 'gpt-35-turbo-0613']\n"
|
"models to use: ['gpt-3.5-turbo-0125']\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -73,7 +73,9 @@
|
||||||
"# a vector database instance\n",
|
"# a vector database instance\n",
|
||||||
"from autogen.retrieve_utils import TEXT_FORMATS\n",
|
"from autogen.retrieve_utils import TEXT_FORMATS\n",
|
||||||
"\n",
|
"\n",
|
||||||
"config_list = autogen.config_list_from_json(env_or_file=\"OAI_CONFIG_LIST\")\n",
|
"config_list = [\n",
|
||||||
|
" {\"model\": \"gpt-3.5-turbo-0125\", \"api_key\": \"<YOUR_API_KEY>\", \"api_type\": \"openai\"},\n",
|
||||||
|
"]\n",
|
||||||
"\n",
|
"\n",
|
||||||
"assert len(config_list) > 0\n",
|
"assert len(config_list) > 0\n",
|
||||||
"print(\"models to use: \", [config_list[i][\"model\"] for i in range(len(config_list))])"
|
"print(\"models to use: \", [config_list[i][\"model\"] for i in range(len(config_list))])"
|
||||||
|
@ -105,7 +107,7 @@
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Accepted file formats for `docs_path`:\n",
|
"Accepted file formats for `docs_path`:\n",
|
||||||
"['ppt', 'jsonl', 'csv', 'yaml', 'rst', 'htm', 'pdf', 'tsv', 'doc', 'docx', 'pptx', 'msg', 'yml', 'xml', 'md', 'json', 'txt', 'epub', 'org', 'xlsx', 'log', 'html', 'odt', 'rtf']\n"
|
"['odt', 'xml', 'pdf', 'docx', 'html', 'md', 'htm', 'csv', 'rst', 'org', 'ppt', 'doc', 'log', 'json', 'epub', 'jsonl', 'pptx', 'yml', 'xlsx', 'tsv', 'txt', 'yaml', 'msg', 'rtf']\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -118,16 +120,7 @@
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 3,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/home/lijiang1/anaconda3/envs/autogen/lib/python3.10/site-packages/transformers/utils/generic.py:311: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
|
|
||||||
" torch.utils._pytree._register_pytree_node(\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"# 1. create an RetrieveAssistantAgent instance named \"assistant\"\n",
|
"# 1. create an RetrieveAssistantAgent instance named \"assistant\"\n",
|
||||||
"assistant = RetrieveAssistantAgent(\n",
|
"assistant = RetrieveAssistantAgent(\n",
|
||||||
|
@ -500,7 +493,7 @@
|
||||||
"# The conversation continues until the termination condition is met, in RetrieveChat, the termination condition when no human-in-loop is no code block detected.\n",
|
"# The conversation continues until the termination condition is met, in RetrieveChat, the termination condition when no human-in-loop is no code block detected.\n",
|
||||||
"# With human-in-loop, the conversation will continue until the user says \"exit\".\n",
|
"# With human-in-loop, the conversation will continue until the user says \"exit\".\n",
|
||||||
"code_problem = \"How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\"\n",
|
"code_problem = \"How can I use FLAML to perform a classification task and use spark to do parallel training. Train 30 seconds and force cancel jobs if time limit is reached.\"\n",
|
||||||
"ragproxyagent.initiate_chat(\n",
|
"chat_result = ragproxyagent.initiate_chat(\n",
|
||||||
" assistant, message=ragproxyagent.message_generator, problem=code_problem, search_string=\"spark\"\n",
|
" assistant, message=ragproxyagent.message_generator, problem=code_problem, search_string=\"spark\"\n",
|
||||||
") # search_string is used as an extra filter for the embeddings search, in this case, we only want to search documents that contain \"spark\"."
|
") # search_string is used as an extra filter for the embeddings search, in this case, we only want to search documents that contain \"spark\"."
|
||||||
]
|
]
|
||||||
|
@ -822,7 +815,7 @@
|
||||||
"assistant.reset()\n",
|
"assistant.reset()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"qa_problem = \"Who is the author of FLAML?\"\n",
|
"qa_problem = \"Who is the author of FLAML?\"\n",
|
||||||
"ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=qa_problem)"
|
"chat_result = ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=qa_problem)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1235,7 +1228,7 @@
|
||||||
"# set `human_input_mode` to be `ALWAYS`, so the agent will ask for human input at every step.\n",
|
"# set `human_input_mode` to be `ALWAYS`, so the agent will ask for human input at every step.\n",
|
||||||
"ragproxyagent.human_input_mode = \"ALWAYS\"\n",
|
"ragproxyagent.human_input_mode = \"ALWAYS\"\n",
|
||||||
"code_problem = \"how to build a time series forecasting model for stock price using FLAML?\"\n",
|
"code_problem = \"how to build a time series forecasting model for stock price using FLAML?\"\n",
|
||||||
"ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=code_problem)"
|
"chat_result = ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=code_problem)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1793,7 +1786,7 @@
|
||||||
"# set `human_input_mode` to be `ALWAYS`, so the agent will ask for human input at every step.\n",
|
"# set `human_input_mode` to be `ALWAYS`, so the agent will ask for human input at every step.\n",
|
||||||
"ragproxyagent.human_input_mode = \"ALWAYS\"\n",
|
"ragproxyagent.human_input_mode = \"ALWAYS\"\n",
|
||||||
"qa_problem = \"Is there a function named `tune_automl` in FLAML?\"\n",
|
"qa_problem = \"Is there a function named `tune_automl` in FLAML?\"\n",
|
||||||
"ragproxyagent.initiate_chat(\n",
|
"chat_result = ragproxyagent.initiate_chat(\n",
|
||||||
" assistant, message=ragproxyagent.message_generator, problem=qa_problem\n",
|
" assistant, message=ragproxyagent.message_generator, problem=qa_problem\n",
|
||||||
") # type \"exit\" to exit the conversation"
|
") # type \"exit\" to exit the conversation"
|
||||||
]
|
]
|
||||||
|
@ -2386,7 +2379,9 @@
|
||||||
" assistant.reset()\n",
|
" assistant.reset()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" qa_problem = questions[i]\n",
|
" qa_problem = questions[i]\n",
|
||||||
" ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=qa_problem, n_results=30)"
|
" chat_result = ragproxyagent.initiate_chat(\n",
|
||||||
|
" assistant, message=ragproxyagent.message_generator, problem=qa_problem, n_results=30\n",
|
||||||
|
" )"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2813,7 +2808,9 @@
|
||||||
" assistant.reset()\n",
|
" assistant.reset()\n",
|
||||||
"\n",
|
"\n",
|
||||||
" qa_problem = questions[i]\n",
|
" qa_problem = questions[i]\n",
|
||||||
" ragproxyagent.initiate_chat(assistant, message=ragproxyagent.message_generator, problem=qa_problem, n_results=10)"
|
" chat_result = ragproxyagent.initiate_chat(\n",
|
||||||
|
" assistant, message=ragproxyagent.message_generator, problem=qa_problem, n_results=10\n",
|
||||||
|
" )"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -2839,7 +2836,7 @@
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.13"
|
"version": "3.11.9"
|
||||||
},
|
},
|
||||||
"skip_test": "Requires interactive usage"
|
"skip_test": "Requires interactive usage"
|
||||||
},
|
},
|
||||||
|
|
|
@ -189,6 +189,7 @@
|
||||||
" file.write(\"\".join(file_contents))\n",
|
" file.write(\"\".join(file_contents))\n",
|
||||||
" return 0, \"Code modified\"\n",
|
" return 0, \"Code modified\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
"@user_proxy.register_for_execution()\n",
|
"@user_proxy.register_for_execution()\n",
|
||||||
"@engineer.register_for_llm(description=\"Create a new file with code.\")\n",
|
"@engineer.register_for_llm(description=\"Create a new file with code.\")\n",
|
||||||
"def create_file_with_code(\n",
|
"def create_file_with_code(\n",
|
||||||
|
|
File diff suppressed because one or more lines are too long
69
setup.py
69
setup.py
|
@ -14,7 +14,7 @@ with open(os.path.join(here, "autogen/version.py")) as fp:
|
||||||
__version__ = version["__version__"]
|
__version__ = version["__version__"]
|
||||||
|
|
||||||
install_requires = [
|
install_requires = [
|
||||||
"openai>=1.3",
|
"openai>=1.23.3",
|
||||||
"diskcache",
|
"diskcache",
|
||||||
"termcolor",
|
"termcolor",
|
||||||
"flaml",
|
"flaml",
|
||||||
|
@ -35,7 +35,43 @@ jupyter_executor = [
|
||||||
"ipykernel>=6.29.0",
|
"ipykernel>=6.29.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
rag = ["sentence_transformers", "pypdf", "ipython", "beautifulsoup4", "markdownify"]
|
retrieve_chat = ["chromadb", "sentence_transformers", "pypdf", "ipython", "beautifulsoup4", "markdownify"]
|
||||||
|
|
||||||
|
extra_require = {
|
||||||
|
"test": [
|
||||||
|
"coverage>=5.3",
|
||||||
|
"ipykernel",
|
||||||
|
"nbconvert",
|
||||||
|
"nbformat",
|
||||||
|
"pre-commit",
|
||||||
|
"pytest-asyncio",
|
||||||
|
"pytest>=6.1.1,<8",
|
||||||
|
"pandas",
|
||||||
|
],
|
||||||
|
"blendsearch": ["flaml[blendsearch]"],
|
||||||
|
"mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
|
||||||
|
"retrievechat": retrieve_chat,
|
||||||
|
"retrievechat-pgvector": [
|
||||||
|
*retrieve_chat,
|
||||||
|
"pgvector>=0.2.5",
|
||||||
|
"psycopg>=3.1.18",
|
||||||
|
],
|
||||||
|
"retrievechat-qdrant": [
|
||||||
|
*retrieve_chat,
|
||||||
|
"qdrant_client[fastembed]",
|
||||||
|
],
|
||||||
|
"autobuild": ["chromadb", "sentence-transformers", "huggingface-hub"],
|
||||||
|
"teachable": ["chromadb"],
|
||||||
|
"lmm": ["replicate", "pillow"],
|
||||||
|
"graph": ["networkx", "matplotlib"],
|
||||||
|
"gemini": ["google-generativeai>=0.5,<1", "pillow", "pydantic"],
|
||||||
|
"websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
|
||||||
|
"redis": ["redis"],
|
||||||
|
"cosmosdb": ["azure-cosmos>=4.2.0"],
|
||||||
|
"websockets": ["websockets>=12.0,<13"],
|
||||||
|
"jupyter-executor": jupyter_executor,
|
||||||
|
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
|
||||||
|
}
|
||||||
|
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name="pyautogen",
|
name="pyautogen",
|
||||||
|
@ -48,34 +84,7 @@ setuptools.setup(
|
||||||
url="https://github.com/microsoft/autogen",
|
url="https://github.com/microsoft/autogen",
|
||||||
packages=setuptools.find_packages(include=["autogen*"], exclude=["test"]),
|
packages=setuptools.find_packages(include=["autogen*"], exclude=["test"]),
|
||||||
install_requires=install_requires,
|
install_requires=install_requires,
|
||||||
extras_require={
|
extras_require=extra_require,
|
||||||
"test": [
|
|
||||||
"coverage>=5.3",
|
|
||||||
"ipykernel",
|
|
||||||
"nbconvert",
|
|
||||||
"nbformat",
|
|
||||||
"pre-commit",
|
|
||||||
"pytest-asyncio",
|
|
||||||
"pytest>=6.1.1,<8",
|
|
||||||
"pandas",
|
|
||||||
],
|
|
||||||
"blendsearch": ["flaml[blendsearch]"],
|
|
||||||
"mathchat": ["sympy", "pydantic==1.10.9", "wolframalpha"],
|
|
||||||
"retrievechat": ["chromadb"] + rag,
|
|
||||||
"retrievechat-pgvector": ["pgvector>=0.2.5", "psycopg>=3.1.18"] + rag,
|
|
||||||
"retrievechat-qdrant": ["qdrant_client[fastembed]"] + rag,
|
|
||||||
"autobuild": ["chromadb", "sentence-transformers", "huggingface-hub"],
|
|
||||||
"teachable": ["chromadb"],
|
|
||||||
"lmm": ["replicate", "pillow"],
|
|
||||||
"graph": ["networkx", "matplotlib"],
|
|
||||||
"gemini": ["google-generativeai>=0.5,<1", "pillow", "pydantic"],
|
|
||||||
"websurfer": ["beautifulsoup4", "markdownify", "pdfminer.six", "pathvalidate"],
|
|
||||||
"redis": ["redis"],
|
|
||||||
"cosmosdb": ["azure-cosmos>=4.2.0"],
|
|
||||||
"websockets": ["websockets>=12.0,<13"],
|
|
||||||
"jupyter-executor": jupyter_executor,
|
|
||||||
"types": ["mypy==1.9.0", "pytest>=6.1.1,<8"] + jupyter_executor,
|
|
||||||
},
|
|
||||||
classifiers=[
|
classifiers=[
|
||||||
"Programming Language :: Python :: 3",
|
"Programming Language :: Python :: 3",
|
||||||
"License :: OSI Approved :: MIT License",
|
"License :: OSI Approved :: MIT License",
|
||||||
|
|
|
@ -9,10 +9,10 @@ from sentence_transformers import SentenceTransformer
|
||||||
from autogen import config_list_from_json
|
from autogen import config_list_from_json
|
||||||
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
|
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
sys.path.append(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
from conftest import skip_openai # noqa: E402
|
from conftest import skip_openai # noqa: E402
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
||||||
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
|
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -27,14 +27,14 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
skip = True
|
skip = True
|
||||||
else:
|
else:
|
||||||
skip = False or skip_openai
|
skip = False
|
||||||
|
|
||||||
|
|
||||||
test_dir = os.path.join(os.path.dirname(__file__), "../..", "test_files")
|
test_dir = os.path.join(os.path.dirname(__file__), "../../..", "test_files")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
skip,
|
skip or skip_openai,
|
||||||
reason="dependency is not installed OR requested to skip",
|
reason="dependency is not installed OR requested to skip",
|
||||||
)
|
)
|
||||||
def test_retrievechat():
|
def test_retrievechat():
|
||||||
|
@ -97,34 +97,5 @@ def test_retrievechat():
|
||||||
print(conversations)
|
print(conversations)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
skip,
|
|
||||||
reason="dependency is not installed",
|
|
||||||
)
|
|
||||||
def test_retrieve_config(caplog):
|
|
||||||
# test warning message when no docs_path is provided
|
|
||||||
ragproxyagent = RetrieveUserProxyAgent(
|
|
||||||
name="ragproxyagent",
|
|
||||||
human_input_mode="NEVER",
|
|
||||||
max_consecutive_auto_reply=2,
|
|
||||||
retrieve_config={
|
|
||||||
"chunk_token_size": 2000,
|
|
||||||
"get_or_create": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Capture the printed content
|
|
||||||
captured_logs = caplog.records[0]
|
|
||||||
print(captured_logs)
|
|
||||||
|
|
||||||
# Assert on the printed content
|
|
||||||
assert (
|
|
||||||
f"docs_path is not provided in retrieve_config. Will raise ValueError if the collection `{ragproxyagent._collection_name}` doesn't exist."
|
|
||||||
in captured_logs.message
|
|
||||||
)
|
|
||||||
assert captured_logs.levelname == "WARNING"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_retrievechat()
|
test_retrievechat()
|
||||||
test_retrieve_config()
|
|
|
@ -8,10 +8,10 @@ import pytest
|
||||||
from autogen import config_list_from_json
|
from autogen import config_list_from_json
|
||||||
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
|
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
sys.path.append(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
from conftest import skip_openai # noqa: E402
|
from conftest import skip_openai # noqa: E402
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
||||||
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
|
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -35,7 +35,7 @@ except ImportError:
|
||||||
else:
|
else:
|
||||||
skip = False or skip_openai
|
skip = False or skip_openai
|
||||||
|
|
||||||
test_dir = os.path.join(os.path.dirname(__file__), "../..", "test_files")
|
test_dir = os.path.join(os.path.dirname(__file__), "../../..", "test_files")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
|
@ -7,10 +7,10 @@ import pytest
|
||||||
|
|
||||||
import autogen
|
import autogen
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
sys.path.append(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
from conftest import skip_openai # noqa: E402
|
from conftest import reason, skip_openai # noqa: E402
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
|
||||||
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
|
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -27,12 +27,14 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
skip = True
|
skip = True
|
||||||
else:
|
else:
|
||||||
skip = False or skip_openai
|
skip = False
|
||||||
|
|
||||||
|
reason = "do not run on MacOS or windows OR dependency is not installed OR " + reason
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.platform in ["darwin", "win32"] or skip,
|
sys.platform in ["darwin", "win32"] or skip or skip_openai,
|
||||||
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
|
reason=reason,
|
||||||
)
|
)
|
||||||
def test_retrievechat():
|
def test_retrievechat():
|
||||||
conversations = {}
|
conversations = {}
|
||||||
|
@ -80,9 +82,9 @@ def test_retrievechat():
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
sys.platform in ["darwin", "win32"] or skip,
|
sys.platform in ["darwin", "win32"] or skip,
|
||||||
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
|
reason=reason,
|
||||||
)
|
)
|
||||||
def test_retrieve_config(caplog):
|
def test_retrieve_config():
|
||||||
# test warning message when no docs_path is provided
|
# test warning message when no docs_path is provided
|
||||||
ragproxyagent = RetrieveUserProxyAgent(
|
ragproxyagent = RetrieveUserProxyAgent(
|
||||||
name="ragproxyagent",
|
name="ragproxyagent",
|
||||||
|
@ -93,17 +95,7 @@ def test_retrieve_config(caplog):
|
||||||
"get_or_create": True,
|
"get_or_create": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
assert ragproxyagent._docs_path is None
|
||||||
# Capture the printed content
|
|
||||||
captured_logs = caplog.records[0]
|
|
||||||
print(captured_logs)
|
|
||||||
|
|
||||||
# Assert on the printed content
|
|
||||||
assert (
|
|
||||||
f"docs_path is not provided in retrieve_config. Will raise ValueError if the collection `{ragproxyagent._collection_name}` doesn't exist."
|
|
||||||
in captured_logs.message
|
|
||||||
)
|
|
||||||
assert captured_logs.levelname == "WARNING"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
|
@ -2,6 +2,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from conftest import reason
|
||||||
|
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
|
||||||
|
@ -9,28 +10,33 @@ try:
|
||||||
import pgvector
|
import pgvector
|
||||||
import sentence_transformers
|
import sentence_transformers
|
||||||
|
|
||||||
from autogen.agentchat.contrib.vectordb.pgvector import PGVector
|
from autogen.agentchat.contrib.vectordb.pgvectordb import PGVectorDB
|
||||||
except ImportError:
|
except ImportError:
|
||||||
skip = True
|
skip = True
|
||||||
else:
|
else:
|
||||||
skip = False
|
skip = False
|
||||||
|
|
||||||
|
reason = "do not run on MacOS or windows OR dependency is not installed OR " + reason
|
||||||
|
|
||||||
@pytest.mark.skipif(skip, reason="dependency is not installed OR requested to skip")
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.platform in ["darwin", "win32"] or skip,
|
||||||
|
reason=reason,
|
||||||
|
)
|
||||||
def test_pgvector():
|
def test_pgvector():
|
||||||
# test create collection
|
# test create collection
|
||||||
db_config = {
|
db_config = {
|
||||||
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
|
"connection_string": "postgresql://postgres:postgres@localhost:5432/postgres",
|
||||||
}
|
}
|
||||||
|
|
||||||
db = PGVector(connection_string=db_config["connection_string"])
|
db = PGVectorDB(connection_string=db_config["connection_string"])
|
||||||
collection_name = "test_collection"
|
collection_name = "test_collection"
|
||||||
collection = db.create_collection(collection_name, overwrite=True, get_or_create=True)
|
collection = db.create_collection(collection_name=collection_name, overwrite=True, get_or_create=True)
|
||||||
assert collection.name == collection_name
|
assert collection.name == collection_name
|
||||||
|
|
||||||
# test_delete_collection
|
# test_delete_collection
|
||||||
db.delete_collection(collection_name)
|
db.delete_collection(collection_name)
|
||||||
pytest.raises(ValueError, db.get_collection, collection_name)
|
assert collection.table_exists(table_name=collection_name) is False
|
||||||
|
|
||||||
# test more create collection
|
# test more create collection
|
||||||
collection = db.create_collection(collection_name, overwrite=False, get_or_create=False)
|
collection = db.create_collection(collection_name, overwrite=False, get_or_create=False)
|
||||||
|
@ -48,21 +54,24 @@ def test_pgvector():
|
||||||
# test_insert_docs
|
# test_insert_docs
|
||||||
docs = [{"content": "doc1", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
|
docs = [{"content": "doc1", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
|
||||||
db.insert_docs(docs, collection_name, upsert=False)
|
db.insert_docs(docs, collection_name, upsert=False)
|
||||||
res = db.get_collection(collection_name).get(["1", "2"])
|
res = db.get_collection(collection_name).get(ids=["1", "2"])
|
||||||
assert res["documents"] == ["doc1", "doc2"]
|
final_results = [result.get("content") for result in res]
|
||||||
|
assert final_results == ["doc1", "doc2"]
|
||||||
|
|
||||||
# test_update_docs
|
# test_update_docs
|
||||||
docs = [{"content": "doc11", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
|
docs = [{"content": "doc11", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
|
||||||
db.update_docs(docs, collection_name)
|
db.update_docs(docs, collection_name)
|
||||||
res = db.get_collection(collection_name).get(["1", "2"])
|
res = db.get_collection(collection_name).get(["1", "2"])
|
||||||
assert res["documents"] == ["doc11", "doc2"]
|
final_results = [result.get("content") for result in res]
|
||||||
|
assert final_results == ["doc11", "doc2"]
|
||||||
|
|
||||||
# test_delete_docs
|
# test_delete_docs
|
||||||
ids = ["1"]
|
ids = ["1"]
|
||||||
collection_name = "test_collection"
|
collection_name = "test_collection"
|
||||||
db.delete_docs(ids, collection_name)
|
db.delete_docs(ids, collection_name)
|
||||||
res = db.get_collection(collection_name).get(ids)
|
res = db.get_collection(collection_name).get(ids)
|
||||||
assert res["documents"] == []
|
final_results = [result.get("content") for result in res]
|
||||||
|
assert final_results == []
|
||||||
|
|
||||||
# test_retrieve_docs
|
# test_retrieve_docs
|
||||||
queries = ["doc2", "doc3"]
|
queries = ["doc2", "doc3"]
|
||||||
|
@ -70,12 +79,13 @@ def test_pgvector():
|
||||||
res = db.retrieve_docs(queries, collection_name)
|
res = db.retrieve_docs(queries, collection_name)
|
||||||
assert [[r[0]["id"] for r in rr] for rr in res] == [["2", "3"], ["3", "2"]]
|
assert [[r[0]["id"] for r in rr] for rr in res] == [["2", "3"], ["3", "2"]]
|
||||||
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1)
|
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1)
|
||||||
print(res)
|
|
||||||
assert [[r[0]["id"] for r in rr] for rr in res] == [["2"], ["3"]]
|
assert [[r[0]["id"] for r in rr] for rr in res] == [["2"], ["3"]]
|
||||||
|
|
||||||
# test_get_docs_by_ids
|
# test_get_docs_by_ids
|
||||||
res = db.get_docs_by_ids(["1", "2"], collection_name)
|
res = db.get_docs_by_ids(["1", "2"], collection_name)
|
||||||
assert [r["id"] for r in res] == ["2"] # "1" has been deleted
|
assert [r["id"] for r in res] == ["2"] # "1" has been deleted
|
||||||
|
res = db.get_docs_by_ids(collection_name=collection_name)
|
||||||
|
assert set([r["id"] for r in res]) == set(["2", "3"]) # All Docs returned
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
Loading…
Reference in New Issue