Fix a initiate chats (#1938)

* Fix async a_initiate_chats

* Fix type compatibility for python 3.8

* Use partial func to fix context error

* website/docs/tutorial/assets/conversable-agent.jpg: convert to Git LFS

* Update notebook examples

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Davor Runje <davor@airt.ai>
This commit is contained in:
Aristo 2024-03-17 15:51:37 -07:00 committed by GitHub
parent 2cefff9206
commit 96cbaf72d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 908 additions and 611 deletions

View File

@ -1,11 +1,13 @@
import asyncio
from functools import partial
import logging
from collections import defaultdict
from collections import defaultdict, abc
from typing import Dict, List, Any, Set, Tuple
from dataclasses import dataclass
from .utils import consolidate_chat_info
import datetime
import warnings
from termcolor import colored
from .utils import consolidate_chat_info
logger = logging.getLogger(__name__)
@ -173,6 +175,49 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
return finished_chats
def __system_now_str():
ct = datetime.datetime.now()
return f" System time at {ct}. "
def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int):
"""
Update ChatResult when async Task for Chat is completed.
"""
logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str())
chat_result = chat_future.result()
chat_result.chat_id = chat_id
async def _dependent_chat_future(
chat_id: int, chat_info: Dict[str, Any], prerequisite_chat_futures: Dict[int, asyncio.Future]
) -> asyncio.Task:
"""
Create an async Task for each chat.
"""
logger.debug(f"Create Task for chat {chat_id}." + __system_now_str())
_chat_carryover = chat_info.get("carryover", [])
finished_chats = dict()
for chat in prerequisite_chat_futures:
chat_future = prerequisite_chat_futures[chat]
if chat_future.cancelled():
raise RuntimeError(f"Chat {chat} is cancelled.")
# wait for prerequisite chat results for the new chat carryover
finished_chats[chat] = await chat_future
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [finished_chats[pre_id].summary for pre_id in finished_chats]
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info))
call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id)
chat_res_future.add_done_callback(call_back_with_args)
logger.debug(f"Task for chat {chat_id} created." + __system_now_str())
return chat_res_future
async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
"""(async) Initiate a list of chats.
@ -183,31 +228,25 @@ async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatRe
returns:
(Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
"""
consolidate_chat_info(chat_queue)
_validate_recipients(chat_queue)
chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue}
num_chats = chat_book.keys()
prerequisites = __create_async_prerequisites(chat_queue)
chat_order_by_id = __find_async_chat_order(num_chats, prerequisites)
finished_chats = dict()
finished_chat_futures = dict()
for chat_id in chat_order_by_id:
chat_info = chat_book[chat_id]
condition = asyncio.Condition()
prerequisite_chat_ids = chat_info.get("prerequisites", [])
async with condition:
await condition.wait_for(lambda: all([id in finished_chats for id in prerequisite_chat_ids]))
# Do the actual work here.
_chat_carryover = chat_info.get("carryover", [])
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [
finished_chats[pre_id].summary for pre_id in prerequisite_chat_ids
]
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res = await sender.a_initiate_chat(**chat_info)
chat_res.chat_id = chat_id
finished_chats[chat_id] = chat_res
pre_chat_futures = dict()
for pre_chat_id in prerequisite_chat_ids:
pre_chat_future = finished_chat_futures[pre_chat_id]
pre_chat_futures[pre_chat_id] = pre_chat_future
current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures)
finished_chat_futures[chat_id] = current_chat_future
await asyncio.gather(*list(finished_chat_futures.values()))
finished_chats = dict()
for chat in finished_chat_futures:
chat_result = finished_chat_futures[chat].result()
finished_chats[chat] = chat_result
return finished_chats

File diff suppressed because one or more lines are too long