From 330262b1b36cf464ec7d570eb66460a00bc2585c Mon Sep 17 00:00:00 2001 From: Joe Landers Date: Wed, 4 Sep 2024 17:55:03 -0700 Subject: [PATCH] Add Human Input Support Updates to *ExtendedConversableAgent* and *ExtendedGroupChatManager* classes - override the `get_human_input` function and async `a_get_human_input` coroutine Updates to *WorkflowManager* classes: - add parameters `a_human_input_function` and `a_human_input_timeout` and pass along on to the ExtendedConversableAgent and ExtendedGroupChatManager - fix for invalid configuration passed from UI when human input mode is not NEVER and no model is attached Updates to *AutoGenChatManager* class: - add parameter `human_input_timeout` and pass it along to *WorkflowManager* classes - add async `a_prompt_for_input` coroutine that relies on `websocket_manager.get_input` coroutine (which snuck into last commit) Updates to *App.py* - global var HUMAN_INPUT_TIMEOUT_SECONDS = 180, we can replace this with a configurable value in the future --- .../autogenstudio/chatmanager.py | 31 +++- .../autogen-studio/autogenstudio/web/app.py | 6 +- .../autogenstudio/workflowmanager.py | 133 ++++++++++++++++++ .../frontend/src/components/atoms.tsx | 46 +++--- .../views/builder/utils/agentconfig.tsx | 4 +- .../components/views/playground/chatbox.tsx | 122 ++++++++++++---- .../components/views/playground/sessions.tsx | 14 +- .../frontend/src/hooks/store.tsx | 4 + 8 files changed, 310 insertions(+), 50 deletions(-) diff --git a/samples/apps/autogen-studio/autogenstudio/chatmanager.py b/samples/apps/autogen-studio/autogenstudio/chatmanager.py index 723e11d63..dd433ae57 100644 --- a/samples/apps/autogen-studio/autogenstudio/chatmanager.py +++ b/samples/apps/autogen-studio/autogenstudio/chatmanager.py @@ -3,8 +3,6 @@ from datetime import datetime from queue import Queue from typing import Any, Dict, List, Optional, Tuple, Union from loguru import logger -import websockets -from fastapi import WebSocket, WebSocketDisconnect from .datamodel import Message from .workflowmanager import WorkflowManager @@ -18,7 +16,8 @@ class AutoGenChatManager: def __init__(self, message_queue: Queue, - websocket_manager: WebSocketConnectionManager = None): + websocket_manager: WebSocketConnectionManager = None, + human_input_timeout: int = 180) -> None: """ Initializes the AutoGenChatManager with a message queue. @@ -26,6 +25,7 @@ class AutoGenChatManager: """ self.message_queue = message_queue self.websocket_manager = websocket_manager + self.a_human_input_timeout = human_input_timeout def send(self, message: dict) -> None: """ @@ -53,6 +53,29 @@ class AutoGenChatManager: f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}" ) + async def a_prompt_for_input(self, prompt: dict, timeout: int = 60) -> str: + """ + Sends the user a prompt and waits for a response asynchronously via the WebSocketManager class + + :param message: The message string to be sent. + """ + + for connection, socket_client_id in self.websocket_manager.active_connections: + if prompt["connection_id"] == socket_client_id: + logger.info( + f"Sending message to connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}" + ) + try: + result = await self.websocket_manager.get_input(prompt, connection, timeout) + return result + except Exception as e: + traceback.print_exc() + return f"Error: {e}\nTERMINATE" + else: + logger.info( + f"Skipping message for connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}" + ) + def chat( self, message: Message, @@ -141,6 +164,8 @@ class AutoGenChatManager: work_dir=work_dir, send_message_function=self.send, a_send_message_function=self.a_send, + a_human_input_function=self.a_prompt_for_input, + a_human_input_timeout=self.a_human_input_timeout, connection_id=connection_id, ) diff --git a/samples/apps/autogen-studio/autogenstudio/web/app.py b/samples/apps/autogen-studio/autogenstudio/web/app.py index 9db32bb36..d7db1b85a 100644 --- a/samples/apps/autogen-studio/autogenstudio/web/app.py +++ b/samples/apps/autogen-studio/autogenstudio/web/app.py @@ -4,7 +4,7 @@ import queue import threading import traceback from contextlib import asynccontextmanager -from typing import Any, Coroutine +from typing import Any, Union from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware @@ -65,12 +65,14 @@ ui_folder_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ui") database_engine_uri = folders["database_engine_uri"] dbmanager = DBManager(engine_uri=database_engine_uri) +HUMAN_INPUT_TIMEOUT_SECONDS = 180 @asynccontextmanager async def lifespan(app: FastAPI): print("***** App started *****") managers["chat"] = AutoGenChatManager(message_queue=message_queue, - websocket_manager=websocket_manager) + websocket_manager=websocket_manager, + human_input_timeout=HUMAN_INPUT_TIMEOUT_SECONDS) dbmanager.create_db_and_tables() yield diff --git a/samples/apps/autogen-studio/autogenstudio/workflowmanager.py b/samples/apps/autogen-studio/autogenstudio/workflowmanager.py index 3c76fa8b3..fe3d698de 100644 --- a/samples/apps/autogen-studio/autogenstudio/workflowmanager.py +++ b/samples/apps/autogen-studio/autogenstudio/workflowmanager.py @@ -41,6 +41,8 @@ class AutoWorkflowManager: clear_work_dir: bool = True, send_message_function: Optional[callable] = None, a_send_message_function: Optional[Coroutine] = None, + a_human_input_function: Optional[callable] = None, + a_human_input_timeout: Optional[int] = 60, connection_id: Optional[str] = None, ) -> None: """ @@ -53,6 +55,8 @@ class AutoWorkflowManager: clear_work_dir (bool): If set to True, clears the working directory. send_message_function (Optional[callable]): The function to send messages. a_send_message_function (Optional[Coroutine]): Async coroutine to send messages. + a_human_input_function (Optional[callable]): Async coroutine to prompt the user for input. + a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation. connection_id (Optional[str]): The connection identifier. """ if isinstance(workflow, str): @@ -70,6 +74,8 @@ class AutoWorkflowManager: self.workflow_skills = [] self.send_message_function = send_message_function self.a_send_message_function = a_send_message_function + self.a_human_input_function = a_human_input_function + self.a_human_input_timeout = a_human_input_timeout self.connection_id = connection_id self.work_dir = work_dir or "work_dir" self.code_executor_pool = { @@ -303,6 +309,12 @@ class AutoWorkflowManager: """ """ skills = agent.get("skills", []) + + # When human input mode is not NEVER and no model is attached, the ui is passing bogus llm_config. + configured_models = agent.get("models") + if not configured_models or len(configured_models) == 0: + agent["config"]["llm_config"] = False + agent = Agent.model_validate(agent) agent.config.is_termination_msg = agent.config.is_termination_msg or ( lambda x: "TERMINATE" in x.get("content", "").rstrip()[-20:] @@ -366,6 +378,9 @@ class AutoWorkflowManager: groupchat=groupchat, message_processor=self.process_message, a_message_processor=self.a_process_message, + a_human_input_function=self.a_human_input_function, + a_human_input_timeout=self.a_human_input_timeout, + connection_id=self.connection_id, llm_config=agent.config.llm_config.model_dump(), ) return agent @@ -376,12 +391,18 @@ class AutoWorkflowManager: **self._serialize_agent(agent), message_processor=self.process_message, a_message_processor=self.a_process_message, + a_human_input_function=self.a_human_input_function, + a_human_input_timeout=self.a_human_input_timeout, + connection_id=self.connection_id, ) elif agent.type == "userproxy": agent = ExtendedConversableAgent( **self._serialize_agent(agent), message_processor=self.process_message, a_message_processor=self.a_process_message, + a_human_input_function=self.a_human_input_function, + a_human_input_timeout=self.a_human_input_timeout, + connection_id=self.connection_id, ) else: raise ValueError(f"Unknown agent type: {agent.type}") @@ -538,6 +559,8 @@ class SequentialWorkflowManager: clear_work_dir: bool = True, send_message_function: Optional[callable] = None, a_send_message_function: Optional[Coroutine] = None, + a_human_input_function: Optional[callable] = None, + a_human_input_timeout: Optional[int] = 60, connection_id: Optional[str] = None, ) -> None: """ @@ -550,6 +573,8 @@ class SequentialWorkflowManager: clear_work_dir (bool): If set to True, clears the working directory. send_message_function (Optional[callable]): The function to send messages. a_send_message_function (Optional[Coroutine]): Async coroutine to send messages. + a_human_input_function (Optional[callable]): Async coroutine to prompt for human input. + a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation. connection_id (Optional[str]): The connection identifier. """ if isinstance(workflow, str): @@ -566,6 +591,8 @@ class SequentialWorkflowManager: # TODO - improved typing for workflow self.send_message_function = send_message_function self.a_send_message_function = a_send_message_function + self.a_human_input_function = a_human_input_function + self.a_human_input_timeout = a_human_input_timeout self.connection_id = connection_id self.work_dir = work_dir or "work_dir" if clear_work_dir: @@ -617,6 +644,7 @@ class SequentialWorkflowManager: clear_work_dir=True, send_message_function=self.send_message_function, a_send_message_function=self.a_send_message_function, + a_human_input_timeout=self.a_human_input_timeout, connection_id=self.connection_id, ) task_prompt = ( @@ -679,6 +707,8 @@ class SequentialWorkflowManager: clear_work_dir=True, send_message_function=self.send_message_function, a_send_message_function=self.a_send_message_function, + a_human_input_function=self.a_human_input_function, + a_human_input_timeout=self.a_human_input_timeout, connection_id=self.connection_id, ) task_prompt = ( @@ -810,6 +840,8 @@ class WorkflowManager: clear_work_dir: bool = True, send_message_function: Optional[callable] = None, a_send_message_function: Optional[Coroutine] = None, + a_human_input_function: Optional[callable] = None, + a_human_input_timeout: Optional[int] = 60, connection_id: Optional[str] = None, ) -> None: """ @@ -822,6 +854,8 @@ class WorkflowManager: clear_work_dir (bool): If set to True, clears the working directory. send_message_function (Optional[callable]): The function to send messages. a_send_message_function (Optional[Coroutine]): Async coroutine to send messages. + a_human_input_function (Optional[callable]): Async coroutine to prompt for user input. + a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation. connection_id (Optional[str]): The connection identifier. """ if isinstance(workflow, str): @@ -843,6 +877,8 @@ class WorkflowManager: clear_work_dir=clear_work_dir, send_message_function=send_message_function, a_send_message_function=a_send_message_function, + a_human_input_function=a_human_input_function, + a_human_input_timeout=a_human_input_timeout, connection_id=connection_id, ) elif self.workflow.get("type") == WorkFlowType.sequential.value: @@ -852,6 +888,9 @@ class WorkflowManager: work_dir=work_dir, clear_work_dir=clear_work_dir, send_message_function=send_message_function, + a_send_message_function=a_send_message_function, + a_human_input_function=a_human_input_function, + a_human_input_timeout=a_human_input_timeout, connection_id=connection_id, ) @@ -860,11 +899,18 @@ class ExtendedConversableAgent(autogen.ConversableAgent): def __init__(self, message_processor=None, a_message_processor=None, + a_human_input_function=None, + a_human_input_timeout: Optional[int] = 60, + connection_id=None, *args, **kwargs): super().__init__(*args, **kwargs) self.message_processor = message_processor self.a_message_processor = a_message_processor + self.a_human_input_function = a_human_input_function + self.a_human_input_response = None + self.a_human_input_timeout = a_human_input_timeout + self.connection_id = connection_id def receive( self, @@ -891,14 +937,65 @@ class ExtendedConversableAgent(autogen.ConversableAgent): await super().a_receive(message, sender, request_reply, silent) + # Strangely, when the response from a_get_human_input == "" (empty string) the libs call into the + # sync version. I guess that's "just in case", but it's odd because replying with an empty string + # is the intended way for the user to signal the underlying libs that they want to system to go forward + # with whatever funciton call, tool call or AI genrated response the request calls for. Oh well, + # Que Sera Sera. + def get_human_input(self, prompt: str) -> str: + if self.a_human_input_response == None: + return super().get_human_input(prompt) + else: + response = self.a_human_input_response + self.a_human_input_response = None + return response + + async def a_get_human_input(self, prompt: str) -> str: + if self.message_processor and self.a_human_input_function: + message_dict = { + "content": prompt, + "role": "system", + "type": "user-input-request" + } + + message_payload = { + "recipient": self.name, + "sender": "system", + "message": message_dict, + "timestamp": datetime.now().isoformat(), + "sender_type": "system", + "connection_id": self.connection_id, + "message_type": "agent_message" + } + + socket_msg = SocketMessage( + type="user_input_request", + data=message_payload, + connection_id=self.connection_id, + ) + self.a_human_input_response = await self.a_human_input_function(socket_msg.dict(), self.a_human_input_timeout) + return self.a_human_input_response + + else: + result = await super().a_get_human_input(prompt) + return result + + class ExtendedGroupChatManager(autogen.GroupChatManager): def __init__(self, message_processor=None, a_message_processor=None, + a_human_input_function=None, + a_human_input_timeout: Optional[int] = 60, + connection_id=None, *args, **kwargs): super().__init__(*args, **kwargs) self.message_processor = message_processor self.a_message_processor = a_message_processor + self.a_human_input_function = a_human_input_function + self.a_human_input_response = None + self.a_human_input_timeout = a_human_input_timeout + self.connection_id = connection_id def receive( self, @@ -925,3 +1022,39 @@ class ExtendedGroupChatManager(autogen.GroupChatManager): await super().a_receive(message, sender, request_reply, silent) + def get_human_input(self, prompt: str) -> str: + if self.a_human_input_response == None: + return super().get_human_input(prompt) + else: + response = self.a_human_input_response + self.a_human_input_response = None + return response + + async def a_get_human_input(self, prompt: str) -> str: + if self.message_processor and self.a_human_input_function: + message_dict = { + "content": prompt, + "role": "system", + "type": "user-input-request" + } + + message_payload = { + "recipient": self.name, + "sender": "system", + "message": message_dict, + "timestamp": datetime.now().isoformat(), + "sender_type": "system", + "connection_id": self.connection_id, + "message_type": "agent_message" + } + socket_msg = SocketMessage( + type="user_input_request", + data=message_payload, + connection_id=self.connection_id, + ) + result = await self.a_human_input_function(socket_msg.dict(), self.a_human_input_timeout) + return result + + else: + result = await super().a_get_human_input(prompt) + return result diff --git a/samples/apps/autogen-studio/frontend/src/components/atoms.tsx b/samples/apps/autogen-studio/frontend/src/components/atoms.tsx index a0864153f..8f52e6028 100644 --- a/samples/apps/autogen-studio/frontend/src/components/atoms.tsx +++ b/samples/apps/autogen-studio/frontend/src/components/atoms.tsx @@ -49,7 +49,7 @@ export const SectionHeader = ({ icon, }: IProps) => { return ( -
+

{/* {count !== null && {count}} */} {icon && <>{icon}} @@ -72,6 +72,7 @@ export const IconButton = ({ }: IProps) => { return ( { return (