Add a warning message if docs_path not explicitly set (#814)

* Add a warning message if docs_path not explicitly set

* update

* Add how to suppress warning message

* Fix tests errors

* Fix tests errors

* Fix tests errors
This commit is contained in:
Li Jiang 2023-12-01 00:34:45 +08:00 committed by GitHub
parent f65494664d
commit ae7066be57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 37 additions and 1 deletions

View File

@ -9,6 +9,7 @@ from autogen.agentchat import UserProxyAgent
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS
from autogen.token_count_utils import count_token
from autogen.code_utils import extract_code
from autogen import logger
from typing import Callable, Dict, Optional, Union, List, Tuple, Any
from IPython import get_ipython
@ -171,6 +172,12 @@ class RetrieveUserProxyAgent(UserProxyAgent):
self._client = self._retrieve_config.get("client", chromadb.Client())
self._docs_path = self._retrieve_config.get("docs_path", None)
self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs")
if "docs_path" not in self._retrieve_config:
logger.warning(
"docs_path is not provided in retrieve_config. "
f"Will raise ValueError if the collection `{self._collection_name}` doesn't exist. "
"Set docs_path to None to suppress this warning."
)
self._model = self._retrieve_config.get("model", "gpt-4")
self._max_tokens = self.get_max_tokens(self._model)
self._chunk_token_size = int(self._retrieve_config.get("chunk_token_size", self._max_tokens * 0.4))

View File

@ -68,5 +68,34 @@ def test_retrievechat():
print(conversations)
@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip_test,
reason="do not run on MacOS or windows or dependency is not installed",
)
def test_retrieve_config(caplog):
# test warning message when no docs_path is provided
ragproxyagent = RetrieveUserProxyAgent(
name="ragproxyagent",
human_input_mode="NEVER",
max_consecutive_auto_reply=2,
retrieve_config={
"chunk_token_size": 2000,
"get_or_create": True,
},
)
# Capture the printed content
captured_logs = caplog.records[0]
print(captured_logs)
# Assert on the printed content
assert (
f"docs_path is not provided in retrieve_config. Will raise ValueError if the collection `{ragproxyagent._collection_name}` doesn't exist."
in captured_logs.message
)
assert captured_logs.levelname == "WARNING"
if __name__ == "__main__":
test_retrievechat()
# test_retrievechat()
test_retrieve_config()