initate chats enhancement (#2404)

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Shobhit Vishnoi 2024-04-18 11:50:48 +05:30 committed by GitHub
parent d307818dd9
commit 8033fc6228
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 7 deletions

View File

@ -171,6 +171,9 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
- `"carryover"` - It can be used to specify the carryover information to be passed - `"carryover"` - It can be used to specify the carryover information to be passed
to this chat. If provided, we will combine this carryover with the "message" content when to this chat. If provided, we will combine this carryover with the "message" content when
generating the initial chat message in `generate_init_message`. generating the initial chat message in `generate_init_message`.
- `"finished_chat_indexes_to_exclude_from_carryover"` - It can be used by specifying a list of indexes of the finished_chats list,
from which to exclude the summaries for carryover. If 'finished_chat_indexes_to_exclude_from_carryover' is not provided or an empty list,
then summary from all the finished chats will be taken.
Returns: Returns:
(list): a list of ChatResult objects corresponding to the finished chats in the chat_queue. (list): a list of ChatResult objects corresponding to the finished chats in the chat_queue.
""" """
@ -182,9 +185,16 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
while current_chat_queue: while current_chat_queue:
chat_info = current_chat_queue.pop(0) chat_info = current_chat_queue.pop(0)
_chat_carryover = chat_info.get("carryover", []) _chat_carryover = chat_info.get("carryover", [])
finished_chat_indexes_to_exclude_from_carryover = chat_info.get(
"finished_chat_indexes_to_exclude_from_carryover", []
)
if isinstance(_chat_carryover, str): if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover] _chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [r.summary for r in finished_chats] chat_info["carryover"] = _chat_carryover + [
r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover
]
__post_carryover_processing(chat_info) __post_carryover_processing(chat_info)
sender = chat_info["sender"] sender = chat_info["sender"]
chat_res = sender.initiate_chat(**chat_info) chat_res = sender.initiate_chat(**chat_info)

View File

@ -1180,6 +1180,23 @@ class ConversableAgent(LLMAgent):
response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache) response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache)
return response return response
def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Check the chat queue and add the "sender" key if it's missing.
Args:
chat_queue (List[Dict[str, Any]]): A list of dictionaries containing chat information.
Returns:
List[Dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing.
"""
chat_queue_with_sender = []
for chat_info in chat_queue:
if chat_info.get("sender") is None:
chat_info["sender"] = self
chat_queue_with_sender.append(chat_info)
return chat_queue_with_sender
def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""(Experimental) Initiate chats with multiple agents. """(Experimental) Initiate chats with multiple agents.
@ -1189,16 +1206,13 @@ class ConversableAgent(LLMAgent):
Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue. Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue.
""" """
_chat_queue = chat_queue.copy() _chat_queue = self._check_chat_queue_for_sender(chat_queue)
for chat_info in _chat_queue:
chat_info["sender"] = self
self._finished_chats = initiate_chats(_chat_queue) self._finished_chats = initiate_chats(_chat_queue)
return self._finished_chats return self._finished_chats
async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]: async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
_chat_queue = chat_queue.copy()
for chat_info in _chat_queue: _chat_queue = self._check_chat_queue_for_sender(chat_queue)
chat_info["sender"] = self
self._finished_chats = await a_initiate_chats(_chat_queue) self._finished_chats = await a_initiate_chats(_chat_queue)
return self._finished_chats return self._finished_chats