mirror of https://github.com/microsoft/autogen.git
Fix a initiate chats (#1938)
* Fix async a_initiate_chats * Fix type compatibility for python 3.8 * Use partial func to fix context error * website/docs/tutorial/assets/conversable-agent.jpg: convert to Git LFS * Update notebook examples --------- Co-authored-by: Chi Wang <wang.chi@microsoft.com> Co-authored-by: Davor Runje <davor@airt.ai>
This commit is contained in:
parent
2cefff9206
commit
96cbaf72d3
|
@ -1,11 +1,13 @@
|
|||
import asyncio
|
||||
from functools import partial
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, abc
|
||||
from typing import Dict, List, Any, Set, Tuple
|
||||
from dataclasses import dataclass
|
||||
from .utils import consolidate_chat_info
|
||||
import datetime
|
||||
import warnings
|
||||
from termcolor import colored
|
||||
from .utils import consolidate_chat_info
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -173,6 +175,49 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
|
|||
return finished_chats
|
||||
|
||||
|
||||
def __system_now_str():
|
||||
ct = datetime.datetime.now()
|
||||
return f" System time at {ct}. "
|
||||
|
||||
|
||||
def _on_chat_future_done(chat_future: asyncio.Future, chat_id: int):
|
||||
"""
|
||||
Update ChatResult when async Task for Chat is completed.
|
||||
"""
|
||||
logger.debug(f"Update chat {chat_id} result on task completion." + __system_now_str())
|
||||
chat_result = chat_future.result()
|
||||
chat_result.chat_id = chat_id
|
||||
|
||||
|
||||
async def _dependent_chat_future(
|
||||
chat_id: int, chat_info: Dict[str, Any], prerequisite_chat_futures: Dict[int, asyncio.Future]
|
||||
) -> asyncio.Task:
|
||||
"""
|
||||
Create an async Task for each chat.
|
||||
"""
|
||||
logger.debug(f"Create Task for chat {chat_id}." + __system_now_str())
|
||||
_chat_carryover = chat_info.get("carryover", [])
|
||||
finished_chats = dict()
|
||||
for chat in prerequisite_chat_futures:
|
||||
chat_future = prerequisite_chat_futures[chat]
|
||||
if chat_future.cancelled():
|
||||
raise RuntimeError(f"Chat {chat} is cancelled.")
|
||||
|
||||
# wait for prerequisite chat results for the new chat carryover
|
||||
finished_chats[chat] = await chat_future
|
||||
|
||||
if isinstance(_chat_carryover, str):
|
||||
_chat_carryover = [_chat_carryover]
|
||||
chat_info["carryover"] = _chat_carryover + [finished_chats[pre_id].summary for pre_id in finished_chats]
|
||||
__post_carryover_processing(chat_info)
|
||||
sender = chat_info["sender"]
|
||||
chat_res_future = asyncio.create_task(sender.a_initiate_chat(**chat_info))
|
||||
call_back_with_args = partial(_on_chat_future_done, chat_id=chat_id)
|
||||
chat_res_future.add_done_callback(call_back_with_args)
|
||||
logger.debug(f"Task for chat {chat_id} created." + __system_now_str())
|
||||
return chat_res_future
|
||||
|
||||
|
||||
async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
|
||||
"""(async) Initiate a list of chats.
|
||||
|
||||
|
@ -183,31 +228,25 @@ async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatRe
|
|||
returns:
|
||||
(Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
|
||||
"""
|
||||
|
||||
consolidate_chat_info(chat_queue)
|
||||
_validate_recipients(chat_queue)
|
||||
chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue}
|
||||
num_chats = chat_book.keys()
|
||||
prerequisites = __create_async_prerequisites(chat_queue)
|
||||
chat_order_by_id = __find_async_chat_order(num_chats, prerequisites)
|
||||
finished_chats = dict()
|
||||
finished_chat_futures = dict()
|
||||
for chat_id in chat_order_by_id:
|
||||
chat_info = chat_book[chat_id]
|
||||
condition = asyncio.Condition()
|
||||
prerequisite_chat_ids = chat_info.get("prerequisites", [])
|
||||
async with condition:
|
||||
await condition.wait_for(lambda: all([id in finished_chats for id in prerequisite_chat_ids]))
|
||||
# Do the actual work here.
|
||||
_chat_carryover = chat_info.get("carryover", [])
|
||||
if isinstance(_chat_carryover, str):
|
||||
_chat_carryover = [_chat_carryover]
|
||||
chat_info["carryover"] = _chat_carryover + [
|
||||
finished_chats[pre_id].summary for pre_id in prerequisite_chat_ids
|
||||
]
|
||||
__post_carryover_processing(chat_info)
|
||||
sender = chat_info["sender"]
|
||||
chat_res = await sender.a_initiate_chat(**chat_info)
|
||||
chat_res.chat_id = chat_id
|
||||
finished_chats[chat_id] = chat_res
|
||||
|
||||
pre_chat_futures = dict()
|
||||
for pre_chat_id in prerequisite_chat_ids:
|
||||
pre_chat_future = finished_chat_futures[pre_chat_id]
|
||||
pre_chat_futures[pre_chat_id] = pre_chat_future
|
||||
current_chat_future = await _dependent_chat_future(chat_id, chat_info, pre_chat_futures)
|
||||
finished_chat_futures[chat_id] = current_chat_future
|
||||
await asyncio.gather(*list(finished_chat_futures.values()))
|
||||
finished_chats = dict()
|
||||
for chat in finished_chat_futures:
|
||||
chat_result = finished_chat_futures[chat].result()
|
||||
finished_chats[chat] = chat_result
|
||||
return finished_chats
|
||||
|
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue