This commit is contained in:
Umer Mansoor 2024-10-23 13:44:14 -04:00 committed by GitHub
commit 78f4f3c8f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 65 additions and 12 deletions

View File

@ -257,17 +257,36 @@ class ConversableAgent(LLMAgent):
}
def _validate_llm_config(self, llm_config):
assert llm_config in (None, False) or isinstance(
llm_config, dict
), "llm_config must be a dict or False or None."
assert llm_config in (None, False) or isinstance(llm_config, dict), "llm_config must be a dict, False, or None."
if llm_config is None:
llm_config = self.DEFAULT_CONFIG
self.llm_config = self.DEFAULT_CONFIG if llm_config is None else llm_config
# TODO: more complete validity check
if self.llm_config in [{}, {"config_list": []}, {"config_list": [{"model": ""}]}]:
self.llm_config = llm_config
if isinstance(self.llm_config, dict):
config_list = self.llm_config.get("config_list", [])
if not isinstance(config_list, list):
raise ValueError("llm_config: 'config_list' must be a list.")
# If config_list is empty, check if 'model' field is present in llm_config
if not config_list and "model" not in self.llm_config:
raise ValueError("llm_config: 'model' field is required in 'llm_config' or each 'config_list' entry.")
# When config_list is not empty, check each item in config_list
for config in config_list:
if not isinstance(config, dict):
raise ValueError("llm_config: 'config_list' must be a list of dictionaries.")
if "model" not in config or not config["model"]:
raise ValueError(
"When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'."
"llm_config: 'model' field is required for each item in 'config_list' and must not be empty."
)
if "api_key" in config and not isinstance(config["api_key"], str):
raise ValueError("llm_config: 'api_key' must be a string.")
if "tags" in config:
if not isinstance(config["tags"], list) or not all(isinstance(tag, str) for tag in config["tags"]):
raise ValueError("llm_config: 'tags' must be a list of strings.")
self.client = None if self.llm_config is False else OpenAIWrapper(**self.llm_config)
@staticmethod

View File

@ -831,10 +831,28 @@ def test_register_for_llm_without_LLM():
return f"{s} done"
def test_register_for_llm_with_valid_configuration():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
valid_config = {
"config_list": [
{"model": "gpt-4"},
{"model": "gpt-4", "api_key": MOCK_OPEN_AI_API_KEY, "tags": ["gpt4", "openai"]},
],
}
assistant = ConversableAgent(
name="assistant",
llm_config=valid_config,
)
assert assistant is not None
assert assistant.llm_config == valid_config
def test_register_for_llm_without_configuration():
with pytest.raises(
ValueError,
match="When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'.",
match="llm_config: 'model' field is required in 'llm_config' or each 'config_list' entry.",
):
ConversableAgent(name="agent", llm_config={"config_list": []})
@ -842,11 +860,27 @@ def test_register_for_llm_without_configuration():
def test_register_for_llm_without_model_name():
with pytest.raises(
ValueError,
match="When using OpenAI or Azure OpenAI endpoints, specify a non-empty 'model' either in 'llm_config' or in each config of 'config_list'.",
match="llm_config: 'model' field is required for each item in 'config_list' and must not be empty.",
):
ConversableAgent(name="agent", llm_config={"config_list": [{"model": ""}]})
def test_register_for_llm_with_invalid_tags():
with pytest.raises(
ValueError,
match="llm_config: 'tags' must be a list of strings.",
):
ConversableAgent(name="agent", llm_config={"config_list": [{"model": "gpt-4", "tags": "invalid_tags"}]})
def test_register_for_llm_with_invalid_api_key():
with pytest.raises(
ValueError,
match="llm_config: 'api_key' must be a string.",
):
ConversableAgent(name="agent", llm_config={"config_list": [{"model": "gpt-4", "api_key": 1234}]})
def test_register_for_execution():
with pytest.MonkeyPatch.context() as mp:
mp.setenv("OPENAI_API_KEY", MOCK_OPEN_AI_API_KEY)
@ -1606,7 +1640,7 @@ def test_http_client():
def test_adding_duplicate_function_warning():
config_base = [{"base_url": "http://0.0.0.0:8000", "api_key": "NULL"}]
config_base = [{"base_url": "http://0.0.0.0:8000", "api_key": "NULL", "model": "na"}]
agent = autogen.ConversableAgent(
"jtoy",