mirror of https://github.com/microsoft/autogen.git
Add Human Input Support
Updates to *ExtendedConversableAgent* and *ExtendedGroupChatManager* classes - override the `get_human_input` function and async `a_get_human_input` coroutine Updates to *WorkflowManager* classes: - add parameters `a_human_input_function` and `a_human_input_timeout` and pass along on to the ExtendedConversableAgent and ExtendedGroupChatManager - fix for invalid configuration passed from UI when human input mode is not NEVER and no model is attached Updates to *AutoGenChatManager* class: - add parameter `human_input_timeout` and pass it along to *WorkflowManager* classes - add async `a_prompt_for_input` coroutine that relies on `websocket_manager.get_input` coroutine (which snuck into last commit) Updates to *App.py* - global var HUMAN_INPUT_TIMEOUT_SECONDS = 180, we can replace this with a configurable value in the future
This commit is contained in:
parent
a6624d8d04
commit
330262b1b3
|
@ -3,8 +3,6 @@ from datetime import datetime
|
|||
from queue import Queue
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from loguru import logger
|
||||
import websockets
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from .datamodel import Message
|
||||
from .workflowmanager import WorkflowManager
|
||||
|
@ -18,7 +16,8 @@ class AutoGenChatManager:
|
|||
|
||||
def __init__(self,
|
||||
message_queue: Queue,
|
||||
websocket_manager: WebSocketConnectionManager = None):
|
||||
websocket_manager: WebSocketConnectionManager = None,
|
||||
human_input_timeout: int = 180) -> None:
|
||||
"""
|
||||
Initializes the AutoGenChatManager with a message queue.
|
||||
|
||||
|
@ -26,6 +25,7 @@ class AutoGenChatManager:
|
|||
"""
|
||||
self.message_queue = message_queue
|
||||
self.websocket_manager = websocket_manager
|
||||
self.a_human_input_timeout = human_input_timeout
|
||||
|
||||
def send(self, message: dict) -> None:
|
||||
"""
|
||||
|
@ -53,6 +53,29 @@ class AutoGenChatManager:
|
|||
f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
|
||||
)
|
||||
|
||||
async def a_prompt_for_input(self, prompt: dict, timeout: int = 60) -> str:
|
||||
"""
|
||||
Sends the user a prompt and waits for a response asynchronously via the WebSocketManager class
|
||||
|
||||
:param message: The message string to be sent.
|
||||
"""
|
||||
|
||||
for connection, socket_client_id in self.websocket_manager.active_connections:
|
||||
if prompt["connection_id"] == socket_client_id:
|
||||
logger.info(
|
||||
f"Sending message to connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}"
|
||||
)
|
||||
try:
|
||||
result = await self.websocket_manager.get_input(prompt, connection, timeout)
|
||||
return result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return f"Error: {e}\nTERMINATE"
|
||||
else:
|
||||
logger.info(
|
||||
f"Skipping message for connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}"
|
||||
)
|
||||
|
||||
def chat(
|
||||
self,
|
||||
message: Message,
|
||||
|
@ -141,6 +164,8 @@ class AutoGenChatManager:
|
|||
work_dir=work_dir,
|
||||
send_message_function=self.send,
|
||||
a_send_message_function=self.a_send,
|
||||
a_human_input_function=self.a_prompt_for_input,
|
||||
a_human_input_timeout=self.a_human_input_timeout,
|
||||
connection_id=connection_id,
|
||||
)
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import queue
|
|||
import threading
|
||||
import traceback
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Coroutine
|
||||
from typing import Any, Union
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
@ -65,12 +65,14 @@ ui_folder_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ui")
|
|||
database_engine_uri = folders["database_engine_uri"]
|
||||
dbmanager = DBManager(engine_uri=database_engine_uri)
|
||||
|
||||
HUMAN_INPUT_TIMEOUT_SECONDS = 180
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
print("***** App started *****")
|
||||
managers["chat"] = AutoGenChatManager(message_queue=message_queue,
|
||||
websocket_manager=websocket_manager)
|
||||
websocket_manager=websocket_manager,
|
||||
human_input_timeout=HUMAN_INPUT_TIMEOUT_SECONDS)
|
||||
dbmanager.create_db_and_tables()
|
||||
|
||||
yield
|
||||
|
|
|
@ -41,6 +41,8 @@ class AutoWorkflowManager:
|
|||
clear_work_dir: bool = True,
|
||||
send_message_function: Optional[callable] = None,
|
||||
a_send_message_function: Optional[Coroutine] = None,
|
||||
a_human_input_function: Optional[callable] = None,
|
||||
a_human_input_timeout: Optional[int] = 60,
|
||||
connection_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -53,6 +55,8 @@ class AutoWorkflowManager:
|
|||
clear_work_dir (bool): If set to True, clears the working directory.
|
||||
send_message_function (Optional[callable]): The function to send messages.
|
||||
a_send_message_function (Optional[Coroutine]): Async coroutine to send messages.
|
||||
a_human_input_function (Optional[callable]): Async coroutine to prompt the user for input.
|
||||
a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation.
|
||||
connection_id (Optional[str]): The connection identifier.
|
||||
"""
|
||||
if isinstance(workflow, str):
|
||||
|
@ -70,6 +74,8 @@ class AutoWorkflowManager:
|
|||
self.workflow_skills = []
|
||||
self.send_message_function = send_message_function
|
||||
self.a_send_message_function = a_send_message_function
|
||||
self.a_human_input_function = a_human_input_function
|
||||
self.a_human_input_timeout = a_human_input_timeout
|
||||
self.connection_id = connection_id
|
||||
self.work_dir = work_dir or "work_dir"
|
||||
self.code_executor_pool = {
|
||||
|
@ -303,6 +309,12 @@ class AutoWorkflowManager:
|
|||
""" """
|
||||
|
||||
skills = agent.get("skills", [])
|
||||
|
||||
# When human input mode is not NEVER and no model is attached, the ui is passing bogus llm_config.
|
||||
configured_models = agent.get("models")
|
||||
if not configured_models or len(configured_models) == 0:
|
||||
agent["config"]["llm_config"] = False
|
||||
|
||||
agent = Agent.model_validate(agent)
|
||||
agent.config.is_termination_msg = agent.config.is_termination_msg or (
|
||||
lambda x: "TERMINATE" in x.get("content", "").rstrip()[-20:]
|
||||
|
@ -366,6 +378,9 @@ class AutoWorkflowManager:
|
|||
groupchat=groupchat,
|
||||
message_processor=self.process_message,
|
||||
a_message_processor=self.a_process_message,
|
||||
a_human_input_function=self.a_human_input_function,
|
||||
a_human_input_timeout=self.a_human_input_timeout,
|
||||
connection_id=self.connection_id,
|
||||
llm_config=agent.config.llm_config.model_dump(),
|
||||
)
|
||||
return agent
|
||||
|
@ -376,12 +391,18 @@ class AutoWorkflowManager:
|
|||
**self._serialize_agent(agent),
|
||||
message_processor=self.process_message,
|
||||
a_message_processor=self.a_process_message,
|
||||
a_human_input_function=self.a_human_input_function,
|
||||
a_human_input_timeout=self.a_human_input_timeout,
|
||||
connection_id=self.connection_id,
|
||||
)
|
||||
elif agent.type == "userproxy":
|
||||
agent = ExtendedConversableAgent(
|
||||
**self._serialize_agent(agent),
|
||||
message_processor=self.process_message,
|
||||
a_message_processor=self.a_process_message,
|
||||
a_human_input_function=self.a_human_input_function,
|
||||
a_human_input_timeout=self.a_human_input_timeout,
|
||||
connection_id=self.connection_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown agent type: {agent.type}")
|
||||
|
@ -538,6 +559,8 @@ class SequentialWorkflowManager:
|
|||
clear_work_dir: bool = True,
|
||||
send_message_function: Optional[callable] = None,
|
||||
a_send_message_function: Optional[Coroutine] = None,
|
||||
a_human_input_function: Optional[callable] = None,
|
||||
a_human_input_timeout: Optional[int] = 60,
|
||||
connection_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -550,6 +573,8 @@ class SequentialWorkflowManager:
|
|||
clear_work_dir (bool): If set to True, clears the working directory.
|
||||
send_message_function (Optional[callable]): The function to send messages.
|
||||
a_send_message_function (Optional[Coroutine]): Async coroutine to send messages.
|
||||
a_human_input_function (Optional[callable]): Async coroutine to prompt for human input.
|
||||
a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation.
|
||||
connection_id (Optional[str]): The connection identifier.
|
||||
"""
|
||||
if isinstance(workflow, str):
|
||||
|
@ -566,6 +591,8 @@ class SequentialWorkflowManager:
|
|||
# TODO - improved typing for workflow
|
||||
self.send_message_function = send_message_function
|
||||
self.a_send_message_function = a_send_message_function
|
||||
self.a_human_input_function = a_human_input_function
|
||||
self.a_human_input_timeout = a_human_input_timeout
|
||||
self.connection_id = connection_id
|
||||
self.work_dir = work_dir or "work_dir"
|
||||
if clear_work_dir:
|
||||
|
@ -617,6 +644,7 @@ class SequentialWorkflowManager:
|
|||
clear_work_dir=True,
|
||||
send_message_function=self.send_message_function,
|
||||
a_send_message_function=self.a_send_message_function,
|
||||
a_human_input_timeout=self.a_human_input_timeout,
|
||||
connection_id=self.connection_id,
|
||||
)
|
||||
task_prompt = (
|
||||
|
@ -679,6 +707,8 @@ class SequentialWorkflowManager:
|
|||
clear_work_dir=True,
|
||||
send_message_function=self.send_message_function,
|
||||
a_send_message_function=self.a_send_message_function,
|
||||
a_human_input_function=self.a_human_input_function,
|
||||
a_human_input_timeout=self.a_human_input_timeout,
|
||||
connection_id=self.connection_id,
|
||||
)
|
||||
task_prompt = (
|
||||
|
@ -810,6 +840,8 @@ class WorkflowManager:
|
|||
clear_work_dir: bool = True,
|
||||
send_message_function: Optional[callable] = None,
|
||||
a_send_message_function: Optional[Coroutine] = None,
|
||||
a_human_input_function: Optional[callable] = None,
|
||||
a_human_input_timeout: Optional[int] = 60,
|
||||
connection_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
|
@ -822,6 +854,8 @@ class WorkflowManager:
|
|||
clear_work_dir (bool): If set to True, clears the working directory.
|
||||
send_message_function (Optional[callable]): The function to send messages.
|
||||
a_send_message_function (Optional[Coroutine]): Async coroutine to send messages.
|
||||
a_human_input_function (Optional[callable]): Async coroutine to prompt for user input.
|
||||
a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation.
|
||||
connection_id (Optional[str]): The connection identifier.
|
||||
"""
|
||||
if isinstance(workflow, str):
|
||||
|
@ -843,6 +877,8 @@ class WorkflowManager:
|
|||
clear_work_dir=clear_work_dir,
|
||||
send_message_function=send_message_function,
|
||||
a_send_message_function=a_send_message_function,
|
||||
a_human_input_function=a_human_input_function,
|
||||
a_human_input_timeout=a_human_input_timeout,
|
||||
connection_id=connection_id,
|
||||
)
|
||||
elif self.workflow.get("type") == WorkFlowType.sequential.value:
|
||||
|
@ -852,6 +888,9 @@ class WorkflowManager:
|
|||
work_dir=work_dir,
|
||||
clear_work_dir=clear_work_dir,
|
||||
send_message_function=send_message_function,
|
||||
a_send_message_function=a_send_message_function,
|
||||
a_human_input_function=a_human_input_function,
|
||||
a_human_input_timeout=a_human_input_timeout,
|
||||
connection_id=connection_id,
|
||||
)
|
||||
|
||||
|
@ -860,11 +899,18 @@ class ExtendedConversableAgent(autogen.ConversableAgent):
|
|||
def __init__(self,
|
||||
message_processor=None,
|
||||
a_message_processor=None,
|
||||
a_human_input_function=None,
|
||||
a_human_input_timeout: Optional[int] = 60,
|
||||
connection_id=None,
|
||||
*args, **kwargs):
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.message_processor = message_processor
|
||||
self.a_message_processor = a_message_processor
|
||||
self.a_human_input_function = a_human_input_function
|
||||
self.a_human_input_response = None
|
||||
self.a_human_input_timeout = a_human_input_timeout
|
||||
self.connection_id = connection_id
|
||||
|
||||
def receive(
|
||||
self,
|
||||
|
@ -891,14 +937,65 @@ class ExtendedConversableAgent(autogen.ConversableAgent):
|
|||
await super().a_receive(message, sender, request_reply, silent)
|
||||
|
||||
|
||||
# Strangely, when the response from a_get_human_input == "" (empty string) the libs call into the
|
||||
# sync version. I guess that's "just in case", but it's odd because replying with an empty string
|
||||
# is the intended way for the user to signal the underlying libs that they want to system to go forward
|
||||
# with whatever funciton call, tool call or AI genrated response the request calls for. Oh well,
|
||||
# Que Sera Sera.
|
||||
def get_human_input(self, prompt: str) -> str:
|
||||
if self.a_human_input_response == None:
|
||||
return super().get_human_input(prompt)
|
||||
else:
|
||||
response = self.a_human_input_response
|
||||
self.a_human_input_response = None
|
||||
return response
|
||||
|
||||
async def a_get_human_input(self, prompt: str) -> str:
|
||||
if self.message_processor and self.a_human_input_function:
|
||||
message_dict = {
|
||||
"content": prompt,
|
||||
"role": "system",
|
||||
"type": "user-input-request"
|
||||
}
|
||||
|
||||
message_payload = {
|
||||
"recipient": self.name,
|
||||
"sender": "system",
|
||||
"message": message_dict,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"sender_type": "system",
|
||||
"connection_id": self.connection_id,
|
||||
"message_type": "agent_message"
|
||||
}
|
||||
|
||||
socket_msg = SocketMessage(
|
||||
type="user_input_request",
|
||||
data=message_payload,
|
||||
connection_id=self.connection_id,
|
||||
)
|
||||
self.a_human_input_response = await self.a_human_input_function(socket_msg.dict(), self.a_human_input_timeout)
|
||||
return self.a_human_input_response
|
||||
|
||||
else:
|
||||
result = await super().a_get_human_input(prompt)
|
||||
return result
|
||||
|
||||
|
||||
class ExtendedGroupChatManager(autogen.GroupChatManager):
|
||||
def __init__(self,
|
||||
message_processor=None,
|
||||
a_message_processor=None,
|
||||
a_human_input_function=None,
|
||||
a_human_input_timeout: Optional[int] = 60,
|
||||
connection_id=None,
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.message_processor = message_processor
|
||||
self.a_message_processor = a_message_processor
|
||||
self.a_human_input_function = a_human_input_function
|
||||
self.a_human_input_response = None
|
||||
self.a_human_input_timeout = a_human_input_timeout
|
||||
self.connection_id = connection_id
|
||||
|
||||
def receive(
|
||||
self,
|
||||
|
@ -925,3 +1022,39 @@ class ExtendedGroupChatManager(autogen.GroupChatManager):
|
|||
await super().a_receive(message, sender, request_reply, silent)
|
||||
|
||||
|
||||
def get_human_input(self, prompt: str) -> str:
|
||||
if self.a_human_input_response == None:
|
||||
return super().get_human_input(prompt)
|
||||
else:
|
||||
response = self.a_human_input_response
|
||||
self.a_human_input_response = None
|
||||
return response
|
||||
|
||||
async def a_get_human_input(self, prompt: str) -> str:
|
||||
if self.message_processor and self.a_human_input_function:
|
||||
message_dict = {
|
||||
"content": prompt,
|
||||
"role": "system",
|
||||
"type": "user-input-request"
|
||||
}
|
||||
|
||||
message_payload = {
|
||||
"recipient": self.name,
|
||||
"sender": "system",
|
||||
"message": message_dict,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"sender_type": "system",
|
||||
"connection_id": self.connection_id,
|
||||
"message_type": "agent_message"
|
||||
}
|
||||
socket_msg = SocketMessage(
|
||||
type="user_input_request",
|
||||
data=message_payload,
|
||||
connection_id=self.connection_id,
|
||||
)
|
||||
result = await self.a_human_input_function(socket_msg.dict(), self.a_human_input_timeout)
|
||||
return result
|
||||
|
||||
else:
|
||||
result = await super().a_get_human_input(prompt)
|
||||
return result
|
||||
|
|
|
@ -49,7 +49,7 @@ export const SectionHeader = ({
|
|||
icon,
|
||||
}: IProps) => {
|
||||
return (
|
||||
<div className="mb-4">
|
||||
<div id="section-header" className="mb-4">
|
||||
<h1 className="text-primary text-2xl">
|
||||
{/* {count !== null && <span className="text-accent mr-1">{count}</span>} */}
|
||||
{icon && <>{icon}</>}
|
||||
|
@ -72,6 +72,7 @@ export const IconButton = ({
|
|||
}: IProps) => {
|
||||
return (
|
||||
<span
|
||||
id="icon-button"
|
||||
role={"button"}
|
||||
onClick={onClick}
|
||||
className={`inline-block mr-2 hover:text-accent transition duration-300 ${className} ${
|
||||
|
@ -90,6 +91,7 @@ export const LaunchButton = ({
|
|||
}: any) => {
|
||||
return (
|
||||
<button
|
||||
id="launch-button"
|
||||
role={"button"}
|
||||
className={` focus:ring ring-accent ring-l-none rounded cursor-pointer hover:brightness-110 bg-accent transition duration-500 text-white ${className} `}
|
||||
onClick={onClick}
|
||||
|
@ -102,6 +104,7 @@ export const LaunchButton = ({
|
|||
export const SecondaryButton = ({ children, onClick, className }: any) => {
|
||||
return (
|
||||
<button
|
||||
id="secondary-button"
|
||||
role={"button"}
|
||||
className={` ${className} focus:ring ring-accent p-2 px-5 rounded cursor-pointer hover:brightness-90 bg-secondary transition duration-500 text-primary`}
|
||||
onClick={onClick}
|
||||
|
@ -128,6 +131,7 @@ export const Card = ({
|
|||
|
||||
return (
|
||||
<button
|
||||
id="card"
|
||||
tabIndex={0}
|
||||
onClick={onClick}
|
||||
role={"button"}
|
||||
|
@ -157,6 +161,7 @@ export const CollapseBox = ({
|
|||
const chevronClass = "h-4 cursor-pointer inline-block mr-1";
|
||||
return (
|
||||
<div
|
||||
id="collapse-box"
|
||||
onMouseDown={(e) => {
|
||||
if (e.detail > 1) {
|
||||
e.preventDefault();
|
||||
|
@ -192,7 +197,7 @@ export const CollapseBox = ({
|
|||
};
|
||||
|
||||
export const HighLight = ({ children }: IProps) => {
|
||||
return <span className="border-b border-accent">{children}</span>;
|
||||
return <span id="highlight" className="border-b border-accent">{children}</span>;
|
||||
};
|
||||
|
||||
export const LoadBox = ({
|
||||
|
@ -200,8 +205,7 @@ export const LoadBox = ({
|
|||
className = "my-2 text-accent ",
|
||||
}: IProps) => {
|
||||
return (
|
||||
<div className={`${className} `}>
|
||||
{" "}
|
||||
<div id="load-box" className={`${className} `}>
|
||||
<span className="mr-2 ">
|
||||
{" "}
|
||||
<Icon size={5} icon="loading" />
|
||||
|
@ -214,7 +218,7 @@ export const LoadBox = ({
|
|||
export const LoadingBar = ({ children }: IProps) => {
|
||||
return (
|
||||
<>
|
||||
<div className="rounded bg-secondary p-3">
|
||||
<div id="loading-bar" className="rounded bg-secondary p-3">
|
||||
<span className="inline-block h-6 w-6 relative mr-2">
|
||||
<Cog8ToothIcon className="animate-ping text-accent absolute inline-flex h-full w-full rounded-ful opacity-75" />
|
||||
<Cog8ToothIcon className="relative text-accent animate-spin inline-flex rounded-full h-6 w-6" />
|
||||
|
@ -239,6 +243,7 @@ export const MessageBox = ({ title, children, className }: IProps) => {
|
|||
|
||||
return (
|
||||
<div
|
||||
id="message-box"
|
||||
ref={messageBox}
|
||||
className={`${className} p-3 rounded bg-secondary transition duration-1000 ease-in-out overflow-hidden`}
|
||||
>
|
||||
|
@ -272,7 +277,7 @@ export const GroupView = ({
|
|||
className = "text-primary bg-primary ",
|
||||
}: any) => {
|
||||
return (
|
||||
<div className={`rounded mt-4 border-secondary ${className}`}>
|
||||
<div id="group-view" className={`rounded mt-4 border-secondary ${className}`}>
|
||||
<div className="mt-4 p-2 rounded border relative">
|
||||
<div className={`absolute -top-3 inline-block ${className}`}>
|
||||
{title}
|
||||
|
@ -297,6 +302,7 @@ export const ExpandView = ({
|
|||
const minImageWidth = 400;
|
||||
return (
|
||||
<div
|
||||
id="expand-view"
|
||||
style={{
|
||||
minHeight: "100px",
|
||||
}}
|
||||
|
@ -347,6 +353,7 @@ export const LoadingOverlay = ({ children, loading }: IProps) => {
|
|||
{loading && (
|
||||
<>
|
||||
<div
|
||||
id="loading-overlay"
|
||||
className="absolute inset-0 bg-secondary flex pointer-events-none"
|
||||
style={{ opacity: 0.5 }}
|
||||
>
|
||||
|
@ -376,7 +383,11 @@ export const MarkdownView = ({
|
|||
showCode?: boolean;
|
||||
}) => {
|
||||
function processString(inputString: string): string {
|
||||
inputString = inputString.replace(/\n/g, " \n");
|
||||
// TODO: Had to add this temp measure while debugging. Why is it null?
|
||||
if (!inputString) {
|
||||
console.log("inputString is null!")
|
||||
}
|
||||
inputString = inputString && inputString.replace(/\n/g, " \n");
|
||||
const markdownPattern = /```markdown\s+([\s\S]*?)\s+```/g;
|
||||
return inputString?.replace(markdownPattern, (match, content) => content);
|
||||
}
|
||||
|
@ -449,6 +460,7 @@ export const MarkdownView = ({
|
|||
|
||||
return (
|
||||
<div
|
||||
id="markdown-view"
|
||||
className={` w-full chatbox prose dark:prose-invert text-primary rounded ${className}`}
|
||||
>
|
||||
<ReactMarkdown
|
||||
|
@ -499,7 +511,7 @@ export const CodeBlock = ({
|
|||
|
||||
const [showCopied, setShowCopied] = React.useState(false);
|
||||
return (
|
||||
<div className="relative">
|
||||
<div id="code-block" className="relative">
|
||||
<div className=" rounded absolute right-5 top-4 z-10 ">
|
||||
<div className="relative border border-transparent w-full h-full">
|
||||
<div
|
||||
|
@ -566,7 +578,7 @@ export const ControlRowView = ({
|
|||
truncateLength?: number;
|
||||
}) => {
|
||||
return (
|
||||
<div className={`${className}`}>
|
||||
<div id="control-row-view" className={`${className}`}>
|
||||
<div>
|
||||
<span className="text-primary inline-block">{title} </span>
|
||||
<span className="text-xs ml-1 text-accent -mt-2 inline-block">
|
||||
|
@ -590,7 +602,7 @@ export const BounceLoader = ({
|
|||
title?: string;
|
||||
}) => {
|
||||
return (
|
||||
<div className="inline-block">
|
||||
<div id="bounce-loader" className="inline-block">
|
||||
<div className="inline-flex gap-2">
|
||||
<span className=" rounded-full bg-accent h-2 w-2 inline-block"></span>
|
||||
<span className="animate-bounce rounded-full bg-accent h-3 w-3 inline-block"></span>
|
||||
|
@ -611,10 +623,10 @@ export const ImageLoader = ({
|
|||
const [isLoading, setIsLoading] = useState(true);
|
||||
|
||||
return (
|
||||
<div className="w-full rounded relative">
|
||||
<div id="image-loader" className="w-full rounded relative">
|
||||
{isLoading && (
|
||||
<div className="absolute h-24 inset-0 flex items-center justify-center">
|
||||
<BounceLoader title=" loading .." />{" "}
|
||||
<BounceLoader title=" loading .." />
|
||||
</div>
|
||||
)}
|
||||
<img
|
||||
|
@ -685,7 +697,7 @@ export const CsvLoader = ({
|
|||
const scrollX = columns.length * 150;
|
||||
|
||||
return (
|
||||
<div className={`CsvLoader ${className}`}>
|
||||
<div id="csv-loader" className={`CsvLoader ${className}`}>
|
||||
<Table
|
||||
dataSource={data}
|
||||
columns={columns}
|
||||
|
@ -720,7 +732,7 @@ export const CodeLoader = ({
|
|||
}, [url]);
|
||||
|
||||
return (
|
||||
<div className={`w-full rounded relative ${className}`}>
|
||||
<div id="code-loader" className={`w-full rounded relative ${className}`}>
|
||||
{isLoading && (
|
||||
<div className="absolute h-24 inset-0 flex items-center justify-center">
|
||||
<BounceLoader />
|
||||
|
@ -743,7 +755,7 @@ export const PdfViewer = ({ url }: { url: string }) => {
|
|||
|
||||
// Render the PDF viewer
|
||||
return (
|
||||
<div className="h-full">
|
||||
<div id="pdf-viewer" className="h-full">
|
||||
{loading && <p>Loading PDF...</p>}
|
||||
{!loading && (
|
||||
<object
|
||||
|
@ -779,7 +791,7 @@ export const MonacoEditor = ({
|
|||
setIsEditorReady(true);
|
||||
};
|
||||
return (
|
||||
<div className="h-full rounded">
|
||||
<div id="monaco-editor" className="h-full rounded">
|
||||
<Editor
|
||||
height="100%"
|
||||
className="h-full rounded"
|
||||
|
@ -820,6 +832,7 @@ export const CardHoverBar = ({
|
|||
return (
|
||||
<div
|
||||
key={"cardhoverrow" + i}
|
||||
id={`card-hover-bar-item-${i}`}
|
||||
role="button"
|
||||
className="text-accent text-xs inline-block hover:bg-primary p-2 rounded"
|
||||
onClick={item.onClick}
|
||||
|
@ -832,6 +845,7 @@ export const CardHoverBar = ({
|
|||
});
|
||||
return (
|
||||
<div
|
||||
id="card-hover-bar"
|
||||
onMouseEnter={(e) => {
|
||||
e.stopPropagation();
|
||||
}}
|
||||
|
|
|
@ -193,8 +193,8 @@ export const AgentConfigView = ({
|
|||
options={
|
||||
[
|
||||
{ label: "NEVER", value: "NEVER" },
|
||||
// { label: "TERMINATE", value: "TERMINATE" },
|
||||
// { label: "ALWAYS", value: "ALWAYS" },
|
||||
{ label: "TERMINATE", value: "TERMINATE" },
|
||||
{ label: "ALWAYS", value: "ALWAYS" },
|
||||
] as any
|
||||
}
|
||||
/>
|
||||
|
|
|
@ -62,6 +62,10 @@ const ChatBox = ({
|
|||
const [workflow, setWorkflow] = React.useState<IWorkflow | null>(null);
|
||||
|
||||
const [socketMessages, setSocketMessages] = React.useState<any[]>([]);
|
||||
const [awaitingUserInput, setAwaitingUserInput] = React.useState(false); // New state for tracking user input
|
||||
const setAreSessionButtonsDisabled = useConfigStore(
|
||||
(state) => state.setAreSessionButtonsDisabled
|
||||
);
|
||||
|
||||
const MAX_RETRIES = 10;
|
||||
const RETRY_INTERVAL = 2000;
|
||||
|
@ -102,7 +106,7 @@ const ChatBox = ({
|
|||
try {
|
||||
meta = JSON.parse(message.meta);
|
||||
} catch (e) {
|
||||
meta = message.meta;
|
||||
meta = message?.meta;
|
||||
}
|
||||
const msg: IChatMessage = {
|
||||
text: message.content,
|
||||
|
@ -122,7 +126,6 @@ const ChatBox = ({
|
|||
const initMsgs: IChatMessage[] = parseMessages(initMessages);
|
||||
setMessages(initMsgs);
|
||||
wsMessages.current = initMsgs;
|
||||
socketMsgs = [];
|
||||
}, [initMessages]);
|
||||
|
||||
const promptButtons = examplePrompts.map((prompt, i) => {
|
||||
|
@ -141,7 +144,7 @@ const ChatBox = ({
|
|||
);
|
||||
});
|
||||
|
||||
const messageListView = messages?.map((message: IChatMessage, i: number) => {
|
||||
const messageListView = messages && messages?.map((message: IChatMessage, i: number) => {
|
||||
const isUser = message.sender === "user";
|
||||
const css = isUser ? "bg-accent text-white " : "bg-light";
|
||||
// console.log("message", message);
|
||||
|
@ -209,7 +212,7 @@ const ChatBox = ({
|
|||
|
||||
return (
|
||||
<div
|
||||
className={`align-right ${isUser ? "text-righpt" : ""} mb-2 border-b`}
|
||||
id={"message" + i} className={`align-right ${isUser ? "text-righpt" : ""} mb-2 border-b`}
|
||||
key={"message" + i}
|
||||
>
|
||||
{" "}
|
||||
|
@ -294,10 +297,10 @@ const ChatBox = ({
|
|||
}, 500);
|
||||
}, [messages]);
|
||||
|
||||
const textAreaDefaultHeight = "50px";
|
||||
const textAreaDefaultHeight = "64px";
|
||||
// clear text box if loading has just changed to false and there is no error
|
||||
React.useEffect(() => {
|
||||
if (loading === false && textAreaInputRef.current) {
|
||||
if ((awaitingUserInput || loading === false) && textAreaInputRef.current) {
|
||||
if (textAreaInputRef.current) {
|
||||
if (error === null || (error && error.status === false)) {
|
||||
textAreaInputRef.current.value = "";
|
||||
|
@ -372,6 +375,19 @@ const ChatBox = ({
|
|||
scrollChatBox(messageBoxInputRef);
|
||||
}, 200);
|
||||
// console.log("received message", data, socketMsgs.length);
|
||||
} else if (data && data.type === "user_input_request") {
|
||||
setAwaitingUserInput(true); // Set awaiting input state
|
||||
textAreaInputRef.current.value = ""
|
||||
textAreaInputRef.current.placeholder = data.data.message.content
|
||||
const newsocketMessages = Object.assign([], socketMessages);
|
||||
newsocketMessages.push(data.data);
|
||||
setSocketMessages(newsocketMessages);
|
||||
socketMsgs.push(data.data);
|
||||
setTimeout(() => {
|
||||
scrollChatBox(socketDivRef);
|
||||
scrollChatBox(messageBoxInputRef);
|
||||
}, 200);
|
||||
ToastMessage.info(data.data.message)
|
||||
} else if (data && data.type === "agent_status") {
|
||||
// indicates a status message update
|
||||
const agentStatusSpan = document.getElementById("agentstatusspan");
|
||||
|
@ -380,6 +396,8 @@ const ChatBox = ({
|
|||
}
|
||||
} else if (data && data.type === "agent_response") {
|
||||
// indicates a final agent response
|
||||
setAwaitingUserInput(false); // Set awaiting input state
|
||||
setAreSessionButtonsDisabled(false);
|
||||
processAgentResponse(data.data);
|
||||
}
|
||||
};
|
||||
|
@ -408,11 +426,13 @@ const ChatBox = ({
|
|||
wsMessages.current.push(msg);
|
||||
setMessages(wsMessages.current);
|
||||
setLoading(false);
|
||||
setAwaitingUserInput(false);
|
||||
} else {
|
||||
console.log("error", data);
|
||||
// setError(data);
|
||||
ToastMessage.error(data.message);
|
||||
setLoading(false);
|
||||
setAwaitingUserInput(false);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -442,6 +462,7 @@ const ChatBox = ({
|
|||
|
||||
const runWorkflow = (query: string) => {
|
||||
setError(null);
|
||||
setAreSessionButtonsDisabled(true);
|
||||
socketMsgs = [];
|
||||
let messageHolder = Object.assign([], messages);
|
||||
|
||||
|
@ -519,6 +540,43 @@ const ChatBox = ({
|
|||
}
|
||||
};
|
||||
|
||||
const sendUserResponse = (userResponse: string) => {
|
||||
setAwaitingUserInput(false);
|
||||
setError(null);
|
||||
setLoading(true);
|
||||
|
||||
textAreaInputRef.current.placeholder = "Write message here..."
|
||||
|
||||
const userMessage: IChatMessage = {
|
||||
text: userResponse,
|
||||
sender: "system",
|
||||
};
|
||||
|
||||
const messagePayload: IMessage = {
|
||||
role: "user",
|
||||
content: userResponse,
|
||||
user_id: user?.email || "",
|
||||
session_id: session?.id,
|
||||
workflow_id: session?.workflow_id,
|
||||
connection_id: connectionId,
|
||||
};
|
||||
|
||||
// check if socket connected,
|
||||
if (wsClient.current && wsClient.current.readyState === 1) {
|
||||
wsClient.current.send(
|
||||
JSON.stringify({
|
||||
connection_id: connectionId,
|
||||
data: messagePayload,
|
||||
type: "user_message",
|
||||
session_id: session?.id,
|
||||
workflow_id: session?.workflow_id,
|
||||
})
|
||||
);
|
||||
} else {
|
||||
console.err("websocket client error")
|
||||
}
|
||||
};
|
||||
|
||||
const handleTextChange = (
|
||||
event: React.ChangeEvent<HTMLTextAreaElement>
|
||||
): void => {
|
||||
|
@ -529,9 +587,14 @@ const ChatBox = ({
|
|||
event: React.KeyboardEvent<HTMLTextAreaElement>
|
||||
): void => {
|
||||
if (event.key === "Enter" && !event.shiftKey) {
|
||||
if (textAreaInputRef.current && !loading) {
|
||||
if (textAreaInputRef.current &&(awaitingUserInput || !loading)) {
|
||||
event.preventDefault();
|
||||
runWorkflow(textAreaInputRef.current.value);
|
||||
if (awaitingUserInput) {
|
||||
sendUserResponse(textAreaInputRef.current.value); // New function call for sending user input
|
||||
textAreaInputRef.current.value = "";
|
||||
} else {
|
||||
runWorkflow(textAreaInputRef.current.value);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -554,7 +617,7 @@ const ChatBox = ({
|
|||
|
||||
const WorkflowView = ({ workflow }: { workflow: IWorkflow }) => {
|
||||
return (
|
||||
<div className="text-xs cursor-pointer inline-block">
|
||||
<div id="workflow-view" className="text-xs cursor-pointer inline-block">
|
||||
{" "}
|
||||
{workflow.name}
|
||||
</div>
|
||||
|
@ -563,11 +626,13 @@ const ChatBox = ({
|
|||
|
||||
return (
|
||||
<div
|
||||
id="chatbox-main"
|
||||
style={{ height: "calc(100vh - " + heightOffset + "px)" }}
|
||||
className="text-primary relative rounded "
|
||||
ref={mainDivRef}
|
||||
>
|
||||
<div
|
||||
id="workflow-name"
|
||||
style={{ zIndex: 100 }}
|
||||
className=" absolute right-3 bg-primary rounded text-secondary -top-6 p-2"
|
||||
>
|
||||
|
@ -576,17 +641,18 @@ const ChatBox = ({
|
|||
</div>
|
||||
|
||||
<div
|
||||
id="message-box"
|
||||
ref={messageBoxInputRef}
|
||||
className="flex h-full flex-col rounded scroll pr-2 overflow-auto "
|
||||
style={{ minHeight: "30px", height: "calc(100vh - 310px)" }}
|
||||
>
|
||||
<div className="scroll-gradient h-10">
|
||||
<div id="scroll-gradient" className="scroll-gradient h-10">
|
||||
{" "}
|
||||
<span className=" inline-block h-6"></span>{" "}
|
||||
</div>
|
||||
<div className="flex-1 boder mt-4"></div>
|
||||
{!messages && messages !== null && (
|
||||
<div className="w-full text-center boder mt-4">
|
||||
<div id="loading-messages" className="w-full text-center boder mt-4">
|
||||
<div>
|
||||
{" "}
|
||||
<BounceLoader />
|
||||
|
@ -596,15 +662,15 @@ const ChatBox = ({
|
|||
)}
|
||||
|
||||
{messages && messages?.length === 0 && (
|
||||
<div className="ml-2 text-sm text-secondary ">
|
||||
<div id="no-messages" className="ml-2 text-sm text-secondary ">
|
||||
<InformationCircleIcon className="inline-block h-6 mr-2" />
|
||||
No messages in the current session. Start a conversation to begin.
|
||||
</div>
|
||||
)}
|
||||
<div className="ml-2"> {messageListView}</div>
|
||||
|
||||
{loading && (
|
||||
<div className={` inline-flex gap-2 duration-300 `}>
|
||||
<div id="message-list" className="ml-2"> {messageListView}</div>
|
||||
{(loading || awaitingUserInput) && (
|
||||
<div id="loading-bar" className={` inline-flex gap-2 duration-300 `}>
|
||||
<div className=""></div>
|
||||
<div className="font-semibold text-secondary text-sm w-16">
|
||||
AGENTS
|
||||
|
@ -632,6 +698,7 @@ const ChatBox = ({
|
|||
|
||||
{socketMsgs.length > 0 && (
|
||||
<div
|
||||
id="agent-messages"
|
||||
ref={socketDivRef}
|
||||
style={{
|
||||
minHeight: "300px",
|
||||
|
@ -661,10 +728,11 @@ const ChatBox = ({
|
|||
)}
|
||||
</div>
|
||||
{editable && (
|
||||
<div className="mt-2 p-2 absolute bg-primary bottom-0 w-full">
|
||||
<div id="input-area" className="mt-2 p-2 absolute bg-primary bottom-0 w-full">
|
||||
<div
|
||||
id="input-form"
|
||||
className={`rounded p-2 shadow-lg flex mb-1 gap-2 ${
|
||||
loading ? " opacity-50 pointer-events-none" : ""
|
||||
loading && !awaitingUserInput ? " opacity-50 pointer-events-none" : ""
|
||||
}`}
|
||||
>
|
||||
{/* <input className="flex-1 p-2 ring-2" /> */}
|
||||
|
@ -683,7 +751,7 @@ const ChatBox = ({
|
|||
onChange={handleTextChange}
|
||||
placeholder="Write message here..."
|
||||
ref={textAreaInputRef}
|
||||
className="flex items-center w-full resize-none text-gray-600 bg-white p-2 ring-2 rounded-sm pl-5 pr-16"
|
||||
className="flex items-center w-full resize-none text-gray-600 bg-white p-2 ring-2 rounded-sm pl-5 pr-16 h-64"
|
||||
style={{
|
||||
maxHeight: "120px",
|
||||
overflowY: "auto",
|
||||
|
@ -691,23 +759,28 @@ const ChatBox = ({
|
|||
}}
|
||||
/>
|
||||
<div
|
||||
id="send-button"
|
||||
role={"button"}
|
||||
style={{ width: "45px", height: "35px" }}
|
||||
title="Send message"
|
||||
onClick={() => {
|
||||
if (textAreaInputRef.current && !loading) {
|
||||
runWorkflow(textAreaInputRef.current.value);
|
||||
if (textAreaInputRef.current && (awaitingUserInput || !loading)) {
|
||||
if (awaitingUserInput) {
|
||||
sendUserResponse(textAreaInputRef.current.value); // Use the new function for user input
|
||||
} else {
|
||||
runWorkflow(textAreaInputRef.current.value);
|
||||
}
|
||||
}
|
||||
}}
|
||||
className="absolute right-3 bottom-2 bg-accent hover:brightness-75 transition duration-300 rounded cursor-pointer flex justify-center items-center"
|
||||
>
|
||||
{" "}
|
||||
{!loading && (
|
||||
{(awaitingUserInput || !loading) && (
|
||||
<div className="inline-block ">
|
||||
<PaperAirplaneIcon className="h-6 w-6 text-white " />{" "}
|
||||
</div>
|
||||
)}
|
||||
{loading && (
|
||||
{loading && !awaitingUserInput && (
|
||||
<div className="inline-block ">
|
||||
<Cog6ToothIcon className="text-white animate-spin rounded-full h-6 w-6" />
|
||||
</div>
|
||||
|
@ -728,15 +801,16 @@ const ChatBox = ({
|
|||
</div>
|
||||
|
||||
<div
|
||||
id="prompt-buttons"
|
||||
className={`mt-2 inline-flex gap-2 flex-wrap ${
|
||||
loading ? "brightness-75 pointer-events-none" : ""
|
||||
(loading && !awaitingUserInput) ? "brightness-75 pointer-events-none" : ""
|
||||
}`}
|
||||
>
|
||||
{promptButtons}
|
||||
</div>
|
||||
</div>
|
||||
{error && !error.status && (
|
||||
<div className="p-2 rounded mt-4 text-orange-500 text-sm">
|
||||
<div id="error-message" className="p-2 rounded mt-4 text-orange-500 text-sm">
|
||||
{" "}
|
||||
<ExclamationTriangleIcon className="h-5 text-orange-500 inline-block mr-2" />{" "}
|
||||
{error.message}
|
||||
|
|
|
@ -51,6 +51,10 @@ const SessionsView = ({}: any) => {
|
|||
const session = useConfigStore((state) => state.session);
|
||||
const setSession = useConfigStore((state) => state.setSession);
|
||||
|
||||
const isSessionButtonsDisabled = useConfigStore(
|
||||
(state) => state.areSessionButtonsDisabled
|
||||
);
|
||||
|
||||
const deleteSession = (session: IChatSession) => {
|
||||
setError(null);
|
||||
setLoading(true);
|
||||
|
@ -284,7 +288,9 @@ const SessionsView = ({}: any) => {
|
|||
return (
|
||||
<div
|
||||
key={"sessionsrow" + index}
|
||||
className="group relative mb-2 pb-1 border-b border-dashed "
|
||||
className={`group relative mb-2 pb-1 border-b border-dashed ${
|
||||
isSessionButtonsDisabled ? "opacity-50 pointer-events-none" : ""
|
||||
}`}
|
||||
>
|
||||
{items.length > 0 && (
|
||||
<div className=" absolute right-2 top-2 group-hover:opacity-100 opacity-0 ">
|
||||
|
@ -295,8 +301,10 @@ const SessionsView = ({}: any) => {
|
|||
className={`rounded p-2 cursor-pointer ${rowClass}`}
|
||||
role="button"
|
||||
onClick={() => {
|
||||
setSession(data);
|
||||
// setWorkflowConfig(data.flow_config);
|
||||
if (!isSessionButtonsDisabled) {
|
||||
setSession(data);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="text-xs mt-1">
|
||||
|
@ -424,7 +432,7 @@ const SessionsView = ({}: any) => {
|
|||
<div className="flex gap-x-2">
|
||||
<div className="flex-1"></div>
|
||||
<LaunchButton
|
||||
className="text-sm p-2 px-3"
|
||||
className={`text-sm p-2 px-3 ${isSessionButtonsDisabled ? 'opacity-50 cursor-not-allowed' : ''}`}
|
||||
onClick={() => {
|
||||
setSelectedSession(sampleSession);
|
||||
setNewSessionModalVisible(true);
|
||||
|
|
|
@ -14,6 +14,8 @@ interface ConfigState {
|
|||
setVersion: (version: string) => void;
|
||||
connectionId: string;
|
||||
setConnectionId: (connectionId: string) => void;
|
||||
areSessionButtonsDisabled: boolean;
|
||||
setAreSessionButtonsDisabled: (disabled: boolean) => void;
|
||||
}
|
||||
|
||||
export const useConfigStore = create<ConfigState>()((set) => ({
|
||||
|
@ -27,4 +29,6 @@ export const useConfigStore = create<ConfigState>()((set) => ({
|
|||
setVersion: (version) => set({ version }),
|
||||
connectionId: uuidv4(),
|
||||
setConnectionId: (connectionId) => set({ connectionId }),
|
||||
areSessionButtonsDisabled: false,
|
||||
setAreSessionButtonsDisabled: (disabled) => set({ areSessionButtonsDisabled: disabled }),
|
||||
}));
|
||||
|
|
Loading…
Reference in New Issue