From ae7066be57baa51e88ce82cd5aad3df3243298ee Mon Sep 17 00:00:00 2001 From: Li Jiang Date: Fri, 1 Dec 2023 00:34:45 +0800 Subject: [PATCH] Add a warning message if docs_path not explicitly set (#814) * Add a warning message if docs_path not explicitly set * update * Add how to suppress warning message * Fix tests errors * Fix tests errors * Fix tests errors --- .../contrib/retrieve_user_proxy_agent.py | 7 +++++ test/agentchat/contrib/test_retrievechat.py | 31 ++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py index 777ac1f87c..7b8e4bf6fe 100644 --- a/autogen/agentchat/contrib/retrieve_user_proxy_agent.py +++ b/autogen/agentchat/contrib/retrieve_user_proxy_agent.py @@ -9,6 +9,7 @@ from autogen.agentchat import UserProxyAgent from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS from autogen.token_count_utils import count_token from autogen.code_utils import extract_code +from autogen import logger from typing import Callable, Dict, Optional, Union, List, Tuple, Any from IPython import get_ipython @@ -171,6 +172,12 @@ class RetrieveUserProxyAgent(UserProxyAgent): self._client = self._retrieve_config.get("client", chromadb.Client()) self._docs_path = self._retrieve_config.get("docs_path", None) self._collection_name = self._retrieve_config.get("collection_name", "autogen-docs") + if "docs_path" not in self._retrieve_config: + logger.warning( + "docs_path is not provided in retrieve_config. " + f"Will raise ValueError if the collection `{self._collection_name}` doesn't exist. " + "Set docs_path to None to suppress this warning." + ) self._model = self._retrieve_config.get("model", "gpt-4") self._max_tokens = self.get_max_tokens(self._model) self._chunk_token_size = int(self._retrieve_config.get("chunk_token_size", self._max_tokens * 0.4)) diff --git a/test/agentchat/contrib/test_retrievechat.py b/test/agentchat/contrib/test_retrievechat.py index d701ebc532..574e3571b6 100644 --- a/test/agentchat/contrib/test_retrievechat.py +++ b/test/agentchat/contrib/test_retrievechat.py @@ -68,5 +68,34 @@ def test_retrievechat(): print(conversations) +@pytest.mark.skipif( + sys.platform in ["darwin", "win32"] or skip_test, + reason="do not run on MacOS or windows or dependency is not installed", +) +def test_retrieve_config(caplog): + # test warning message when no docs_path is provided + ragproxyagent = RetrieveUserProxyAgent( + name="ragproxyagent", + human_input_mode="NEVER", + max_consecutive_auto_reply=2, + retrieve_config={ + "chunk_token_size": 2000, + "get_or_create": True, + }, + ) + + # Capture the printed content + captured_logs = caplog.records[0] + print(captured_logs) + + # Assert on the printed content + assert ( + f"docs_path is not provided in retrieve_config. Will raise ValueError if the collection `{ragproxyagent._collection_name}` doesn't exist." + in captured_logs.message + ) + assert captured_logs.levelname == "WARNING" + + if __name__ == "__main__": - test_retrievechat() + # test_retrievechat() + test_retrieve_config()