From 413a458f7527f58724338e4e5f73c91f6e0ad73e Mon Sep 17 00:00:00 2001
From: Chi Wang
Date: Tue, 28 May 2024 06:35:55 -0700
Subject: [PATCH] print next speaker (#2800)

* print next speaker

* fix test error
---
 autogen/agentchat/groupchat.py |  3 +++
 test/oai/test_client.py        | 35 ++++++++++++++++++----------------
 test/oai/test_client_stream.py | 13 ++++++-------
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py
index 83c426272..f6e5d3026 100644
--- a/autogen/agentchat/groupchat.py
+++ b/autogen/agentchat/groupchat.py
@@ -1051,6 +1051,9 @@ class GroupChatManager(ConversableAgent):
             try:
                 # select the next speaker
                 speaker = groupchat.select_speaker(speaker, self)
+                if not silent:
+                    iostream = IOStream.get_default()
+                    iostream.print(colored(f"\nNext speaker: {speaker.name}\n", "green"), flush=True)
                 # let the speaker speak
                 reply = speaker.generate_reply(sender=self)
             except KeyboardInterrupt:
diff --git a/test/oai/test_client.py b/test/oai/test_client.py
index ba2ac0fbe..1cbb43777 100755
--- a/test/oai/test_client.py
+++ b/test/oai/test_client.py
@@ -7,7 +7,7 @@ import time
 
 import pytest
 
-from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
+from autogen import OpenAIWrapper, config_list_from_json
 from autogen.cache.cache import Cache
 from autogen.oai.client import LEGACY_CACHE_DIR, LEGACY_DEFAULT_CACHE_SEED
 
@@ -104,10 +104,11 @@ def test_chat_completion():
 
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
 def test_completion():
-    config_list = config_list_openai_aoai(KEY_LOC)
+    config_list = config_list_from_json(
+        env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"tags": ["gpt-3.5-turbo-instruct"]}
+    )
     client = OpenAIWrapper(config_list=config_list)
-    model = "gpt-3.5-turbo-instruct"
-    response = client.create(prompt="1+1=", model=model)
+    response = client.create(prompt="1+1=")
     print(response)
     print(client.extract_text_or_completion_object(response))
 
@@ -121,19 +122,21 @@ def test_completion():
     ],
 )
 def test_cost(cache_seed):
-    config_list = config_list_openai_aoai(KEY_LOC)
-    model = "gpt-3.5-turbo-instruct"
+    config_list = config_list_from_json(
+        env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"tags": ["gpt-3.5-turbo-instruct"]}
+    )
     client = OpenAIWrapper(config_list=config_list, cache_seed=cache_seed)
-    response = client.create(prompt="1+3=", model=model)
+    response = client.create(prompt="1+3=")
     print(response.cost)
 
 
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
 def test_usage_summary():
-    config_list = config_list_openai_aoai(KEY_LOC)
+    config_list = config_list_from_json(
+        env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"tags": ["gpt-3.5-turbo-instruct"]}
+    )
     client = OpenAIWrapper(config_list=config_list)
-    model = "gpt-3.5-turbo-instruct"
-    response = client.create(prompt="1+3=", model=model, cache_seed=None)
+    response = client.create(prompt="1+3=", cache_seed=None)
 
     # usage should be recorded
     assert client.actual_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
@@ -148,14 +151,14 @@ def test_usage_summary():
     assert client.total_usage_summary is None, "total_usage_summary should be None"
 
     # actual usage and all usage should be different
-    response = client.create(prompt="1+3=", model=model, cache_seed=42)
+    response = client.create(prompt="1+3=", cache_seed=42)
     assert client.total_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
     client.clear_usage_summary()
-    response = client.create(prompt="1+3=", model=model, cache_seed=42)
+    response = client.create(prompt="1+3=", cache_seed=42)
     assert client.actual_usage_summary is None, "No actual cost should be recorded"
 
     # check update
-    response = client.create(prompt="1+3=", model=model, cache_seed=42)
+    response = client.create(prompt="1+3=", cache_seed=42)
     assert (
         client.total_usage_summary["total_cost"] == response.cost * 2
     ), "total_cost should be equal to response.cost * 2"
@@ -303,8 +306,8 @@ if __name__ == "__main__":
     # test_aoai_chat_completion()
     # test_oai_tool_calling_extraction()
     # test_chat_completion()
-    # test_completion()
+    test_completion()
     # # test_cost()
     # test_usage_summary()
-    test_legacy_cache()
-    test_cache()
+    # test_legacy_cache()
+    # test_cache()
diff --git a/test/oai/test_client_stream.py b/test/oai/test_client_stream.py
index 456a8fe76..1e0f3055d 100755
--- a/test/oai/test_client_stream.py
+++ b/test/oai/test_client_stream.py
@@ -1,14 +1,13 @@
 #!/usr/bin/env python3 -m pytest
 
-import json
 import os
 import sys
-from typing import Any, Dict, List, Literal, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 from unittest.mock import MagicMock
 
 import pytest
 
-from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
+from autogen import OpenAIWrapper, config_list_from_json
 
 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
 from conftest import skip_openai  # noqa: E402
@@ -280,11 +279,11 @@ def test_chat_tools_stream() -> None:
 
 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
 def test_completion_stream() -> None:
-    config_list = config_list_openai_aoai(KEY_LOC)
+    config_list = config_list_from_json(
+        env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"tags": ["gpt-3.5-turbo-instruct"]}
+    )
     client = OpenAIWrapper(config_list=config_list)
-    # Azure can't have dot in model/deployment name
-    model = "gpt-35-turbo-instruct" if config_list[0].get("api_type") == "azure" else "gpt-3.5-turbo-instruct"
-    response = client.create(prompt="1+1=", model=model, stream=True)
+    response = client.create(prompt="1+1=", stream=True)
     print(response)
     print(client.extract_text_or_completion_object(response))