Add Human Input Support

Updates to *ExtendedConversableAgent* and *ExtendedGroupChatManager* classes
- override the `get_human_input` function and async `a_get_human_input` coroutine

Updates to *WorkflowManager* classes:
- add parameters `a_human_input_function` and `a_human_input_timeout` and pass along on to the ExtendedConversableAgent and ExtendedGroupChatManager
- fix for invalid configuration passed from UI when human input mode is not NEVER and no model is attached

Updates to *AutoGenChatManager* class:
- add parameter `human_input_timeout` and pass it along to *WorkflowManager* classes
- add async `a_prompt_for_input` coroutine that relies on `websocket_manager.get_input` coroutine (which snuck into last commit)

Updates to *App.py*
- global var HUMAN_INPUT_TIMEOUT_SECONDS = 180, we can replace this with a configurable value in the future
This commit is contained in:
Joe Landers 2024-09-04 17:55:03 -07:00
parent a6624d8d04
commit 330262b1b3
8 changed files with 310 additions and 50 deletions

View File

@ -3,8 +3,6 @@ from datetime import datetime
from queue import Queue
from typing import Any, Dict, List, Optional, Tuple, Union
from loguru import logger
import websockets
from fastapi import WebSocket, WebSocketDisconnect
from .datamodel import Message
from .workflowmanager import WorkflowManager
@ -18,7 +16,8 @@ class AutoGenChatManager:
def __init__(self,
message_queue: Queue,
websocket_manager: WebSocketConnectionManager = None):
websocket_manager: WebSocketConnectionManager = None,
human_input_timeout: int = 180) -> None:
"""
Initializes the AutoGenChatManager with a message queue.
@ -26,6 +25,7 @@ class AutoGenChatManager:
"""
self.message_queue = message_queue
self.websocket_manager = websocket_manager
self.a_human_input_timeout = human_input_timeout
def send(self, message: dict) -> None:
"""
@ -53,6 +53,29 @@ class AutoGenChatManager:
f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}"
)
async def a_prompt_for_input(self, prompt: dict, timeout: int = 60) -> str:
"""
Sends the user a prompt and waits for a response asynchronously via the WebSocketManager class
:param message: The message string to be sent.
"""
for connection, socket_client_id in self.websocket_manager.active_connections:
if prompt["connection_id"] == socket_client_id:
logger.info(
f"Sending message to connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}"
)
try:
result = await self.websocket_manager.get_input(prompt, connection, timeout)
return result
except Exception as e:
traceback.print_exc()
return f"Error: {e}\nTERMINATE"
else:
logger.info(
f"Skipping message for connection_id: {prompt['connection_id']}. Connection ID: {socket_client_id}"
)
def chat(
self,
message: Message,
@ -141,6 +164,8 @@ class AutoGenChatManager:
work_dir=work_dir,
send_message_function=self.send,
a_send_message_function=self.a_send,
a_human_input_function=self.a_prompt_for_input,
a_human_input_timeout=self.a_human_input_timeout,
connection_id=connection_id,
)

View File

@ -4,7 +4,7 @@ import queue
import threading
import traceback
from contextlib import asynccontextmanager
from typing import Any, Coroutine
from typing import Any, Union
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
@ -65,12 +65,14 @@ ui_folder_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ui")
database_engine_uri = folders["database_engine_uri"]
dbmanager = DBManager(engine_uri=database_engine_uri)
HUMAN_INPUT_TIMEOUT_SECONDS = 180
@asynccontextmanager
async def lifespan(app: FastAPI):
print("***** App started *****")
managers["chat"] = AutoGenChatManager(message_queue=message_queue,
websocket_manager=websocket_manager)
websocket_manager=websocket_manager,
human_input_timeout=HUMAN_INPUT_TIMEOUT_SECONDS)
dbmanager.create_db_and_tables()
yield

View File

@ -41,6 +41,8 @@ class AutoWorkflowManager:
clear_work_dir: bool = True,
send_message_function: Optional[callable] = None,
a_send_message_function: Optional[Coroutine] = None,
a_human_input_function: Optional[callable] = None,
a_human_input_timeout: Optional[int] = 60,
connection_id: Optional[str] = None,
) -> None:
"""
@ -53,6 +55,8 @@ class AutoWorkflowManager:
clear_work_dir (bool): If set to True, clears the working directory.
send_message_function (Optional[callable]): The function to send messages.
a_send_message_function (Optional[Coroutine]): Async coroutine to send messages.
a_human_input_function (Optional[callable]): Async coroutine to prompt the user for input.
a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation.
connection_id (Optional[str]): The connection identifier.
"""
if isinstance(workflow, str):
@ -70,6 +74,8 @@ class AutoWorkflowManager:
self.workflow_skills = []
self.send_message_function = send_message_function
self.a_send_message_function = a_send_message_function
self.a_human_input_function = a_human_input_function
self.a_human_input_timeout = a_human_input_timeout
self.connection_id = connection_id
self.work_dir = work_dir or "work_dir"
self.code_executor_pool = {
@ -303,6 +309,12 @@ class AutoWorkflowManager:
""" """
skills = agent.get("skills", [])
# When human input mode is not NEVER and no model is attached, the ui is passing bogus llm_config.
configured_models = agent.get("models")
if not configured_models or len(configured_models) == 0:
agent["config"]["llm_config"] = False
agent = Agent.model_validate(agent)
agent.config.is_termination_msg = agent.config.is_termination_msg or (
lambda x: "TERMINATE" in x.get("content", "").rstrip()[-20:]
@ -366,6 +378,9 @@ class AutoWorkflowManager:
groupchat=groupchat,
message_processor=self.process_message,
a_message_processor=self.a_process_message,
a_human_input_function=self.a_human_input_function,
a_human_input_timeout=self.a_human_input_timeout,
connection_id=self.connection_id,
llm_config=agent.config.llm_config.model_dump(),
)
return agent
@ -376,12 +391,18 @@ class AutoWorkflowManager:
**self._serialize_agent(agent),
message_processor=self.process_message,
a_message_processor=self.a_process_message,
a_human_input_function=self.a_human_input_function,
a_human_input_timeout=self.a_human_input_timeout,
connection_id=self.connection_id,
)
elif agent.type == "userproxy":
agent = ExtendedConversableAgent(
**self._serialize_agent(agent),
message_processor=self.process_message,
a_message_processor=self.a_process_message,
a_human_input_function=self.a_human_input_function,
a_human_input_timeout=self.a_human_input_timeout,
connection_id=self.connection_id,
)
else:
raise ValueError(f"Unknown agent type: {agent.type}")
@ -538,6 +559,8 @@ class SequentialWorkflowManager:
clear_work_dir: bool = True,
send_message_function: Optional[callable] = None,
a_send_message_function: Optional[Coroutine] = None,
a_human_input_function: Optional[callable] = None,
a_human_input_timeout: Optional[int] = 60,
connection_id: Optional[str] = None,
) -> None:
"""
@ -550,6 +573,8 @@ class SequentialWorkflowManager:
clear_work_dir (bool): If set to True, clears the working directory.
send_message_function (Optional[callable]): The function to send messages.
a_send_message_function (Optional[Coroutine]): Async coroutine to send messages.
a_human_input_function (Optional[callable]): Async coroutine to prompt for human input.
a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation.
connection_id (Optional[str]): The connection identifier.
"""
if isinstance(workflow, str):
@ -566,6 +591,8 @@ class SequentialWorkflowManager:
# TODO - improved typing for workflow
self.send_message_function = send_message_function
self.a_send_message_function = a_send_message_function
self.a_human_input_function = a_human_input_function
self.a_human_input_timeout = a_human_input_timeout
self.connection_id = connection_id
self.work_dir = work_dir or "work_dir"
if clear_work_dir:
@ -617,6 +644,7 @@ class SequentialWorkflowManager:
clear_work_dir=True,
send_message_function=self.send_message_function,
a_send_message_function=self.a_send_message_function,
a_human_input_timeout=self.a_human_input_timeout,
connection_id=self.connection_id,
)
task_prompt = (
@ -679,6 +707,8 @@ class SequentialWorkflowManager:
clear_work_dir=True,
send_message_function=self.send_message_function,
a_send_message_function=self.a_send_message_function,
a_human_input_function=self.a_human_input_function,
a_human_input_timeout=self.a_human_input_timeout,
connection_id=self.connection_id,
)
task_prompt = (
@ -810,6 +840,8 @@ class WorkflowManager:
clear_work_dir: bool = True,
send_message_function: Optional[callable] = None,
a_send_message_function: Optional[Coroutine] = None,
a_human_input_function: Optional[callable] = None,
a_human_input_timeout: Optional[int] = 60,
connection_id: Optional[str] = None,
) -> None:
"""
@ -822,6 +854,8 @@ class WorkflowManager:
clear_work_dir (bool): If set to True, clears the working directory.
send_message_function (Optional[callable]): The function to send messages.
a_send_message_function (Optional[Coroutine]): Async coroutine to send messages.
a_human_input_function (Optional[callable]): Async coroutine to prompt for user input.
a_human_input_timeout (Optional[int]): A time (in seconds) to wait for user input. After this time, the a_human_input_function will timeout and end the conversation.
connection_id (Optional[str]): The connection identifier.
"""
if isinstance(workflow, str):
@ -843,6 +877,8 @@ class WorkflowManager:
clear_work_dir=clear_work_dir,
send_message_function=send_message_function,
a_send_message_function=a_send_message_function,
a_human_input_function=a_human_input_function,
a_human_input_timeout=a_human_input_timeout,
connection_id=connection_id,
)
elif self.workflow.get("type") == WorkFlowType.sequential.value:
@ -852,6 +888,9 @@ class WorkflowManager:
work_dir=work_dir,
clear_work_dir=clear_work_dir,
send_message_function=send_message_function,
a_send_message_function=a_send_message_function,
a_human_input_function=a_human_input_function,
a_human_input_timeout=a_human_input_timeout,
connection_id=connection_id,
)
@ -860,11 +899,18 @@ class ExtendedConversableAgent(autogen.ConversableAgent):
def __init__(self,
message_processor=None,
a_message_processor=None,
a_human_input_function=None,
a_human_input_timeout: Optional[int] = 60,
connection_id=None,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.message_processor = message_processor
self.a_message_processor = a_message_processor
self.a_human_input_function = a_human_input_function
self.a_human_input_response = None
self.a_human_input_timeout = a_human_input_timeout
self.connection_id = connection_id
def receive(
self,
@ -891,14 +937,65 @@ class ExtendedConversableAgent(autogen.ConversableAgent):
await super().a_receive(message, sender, request_reply, silent)
# Strangely, when the response from a_get_human_input == "" (empty string) the libs call into the
# sync version. I guess that's "just in case", but it's odd because replying with an empty string
# is the intended way for the user to signal the underlying libs that they want to system to go forward
# with whatever funciton call, tool call or AI genrated response the request calls for. Oh well,
# Que Sera Sera.
def get_human_input(self, prompt: str) -> str:
if self.a_human_input_response == None:
return super().get_human_input(prompt)
else:
response = self.a_human_input_response
self.a_human_input_response = None
return response
async def a_get_human_input(self, prompt: str) -> str:
if self.message_processor and self.a_human_input_function:
message_dict = {
"content": prompt,
"role": "system",
"type": "user-input-request"
}
message_payload = {
"recipient": self.name,
"sender": "system",
"message": message_dict,
"timestamp": datetime.now().isoformat(),
"sender_type": "system",
"connection_id": self.connection_id,
"message_type": "agent_message"
}
socket_msg = SocketMessage(
type="user_input_request",
data=message_payload,
connection_id=self.connection_id,
)
self.a_human_input_response = await self.a_human_input_function(socket_msg.dict(), self.a_human_input_timeout)
return self.a_human_input_response
else:
result = await super().a_get_human_input(prompt)
return result
class ExtendedGroupChatManager(autogen.GroupChatManager):
def __init__(self,
message_processor=None,
a_message_processor=None,
a_human_input_function=None,
a_human_input_timeout: Optional[int] = 60,
connection_id=None,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.message_processor = message_processor
self.a_message_processor = a_message_processor
self.a_human_input_function = a_human_input_function
self.a_human_input_response = None
self.a_human_input_timeout = a_human_input_timeout
self.connection_id = connection_id
def receive(
self,
@ -925,3 +1022,39 @@ class ExtendedGroupChatManager(autogen.GroupChatManager):
await super().a_receive(message, sender, request_reply, silent)
def get_human_input(self, prompt: str) -> str:
if self.a_human_input_response == None:
return super().get_human_input(prompt)
else:
response = self.a_human_input_response
self.a_human_input_response = None
return response
async def a_get_human_input(self, prompt: str) -> str:
if self.message_processor and self.a_human_input_function:
message_dict = {
"content": prompt,
"role": "system",
"type": "user-input-request"
}
message_payload = {
"recipient": self.name,
"sender": "system",
"message": message_dict,
"timestamp": datetime.now().isoformat(),
"sender_type": "system",
"connection_id": self.connection_id,
"message_type": "agent_message"
}
socket_msg = SocketMessage(
type="user_input_request",
data=message_payload,
connection_id=self.connection_id,
)
result = await self.a_human_input_function(socket_msg.dict(), self.a_human_input_timeout)
return result
else:
result = await super().a_get_human_input(prompt)
return result

View File

@ -49,7 +49,7 @@ export const SectionHeader = ({
icon,
}: IProps) => {
return (
<div className="mb-4">
<div id="section-header" className="mb-4">
<h1 className="text-primary text-2xl">
{/* {count !== null && <span className="text-accent mr-1">{count}</span>} */}
{icon && <>{icon}</>}
@ -72,6 +72,7 @@ export const IconButton = ({
}: IProps) => {
return (
<span
id="icon-button"
role={"button"}
onClick={onClick}
className={`inline-block mr-2 hover:text-accent transition duration-300 ${className} ${
@ -90,6 +91,7 @@ export const LaunchButton = ({
}: any) => {
return (
<button
id="launch-button"
role={"button"}
className={` focus:ring ring-accent ring-l-none rounded cursor-pointer hover:brightness-110 bg-accent transition duration-500 text-white ${className} `}
onClick={onClick}
@ -102,6 +104,7 @@ export const LaunchButton = ({
export const SecondaryButton = ({ children, onClick, className }: any) => {
return (
<button
id="secondary-button"
role={"button"}
className={` ${className} focus:ring ring-accent p-2 px-5 rounded cursor-pointer hover:brightness-90 bg-secondary transition duration-500 text-primary`}
onClick={onClick}
@ -128,6 +131,7 @@ export const Card = ({
return (
<button
id="card"
tabIndex={0}
onClick={onClick}
role={"button"}
@ -157,6 +161,7 @@ export const CollapseBox = ({
const chevronClass = "h-4 cursor-pointer inline-block mr-1";
return (
<div
id="collapse-box"
onMouseDown={(e) => {
if (e.detail > 1) {
e.preventDefault();
@ -192,7 +197,7 @@ export const CollapseBox = ({
};
export const HighLight = ({ children }: IProps) => {
return <span className="border-b border-accent">{children}</span>;
return <span id="highlight" className="border-b border-accent">{children}</span>;
};
export const LoadBox = ({
@ -200,8 +205,7 @@ export const LoadBox = ({
className = "my-2 text-accent ",
}: IProps) => {
return (
<div className={`${className} `}>
{" "}
<div id="load-box" className={`${className} `}>
<span className="mr-2 ">
{" "}
<Icon size={5} icon="loading" />
@ -214,7 +218,7 @@ export const LoadBox = ({
export const LoadingBar = ({ children }: IProps) => {
return (
<>
<div className="rounded bg-secondary p-3">
<div id="loading-bar" className="rounded bg-secondary p-3">
<span className="inline-block h-6 w-6 relative mr-2">
<Cog8ToothIcon className="animate-ping text-accent absolute inline-flex h-full w-full rounded-ful opacity-75" />
<Cog8ToothIcon className="relative text-accent animate-spin inline-flex rounded-full h-6 w-6" />
@ -239,6 +243,7 @@ export const MessageBox = ({ title, children, className }: IProps) => {
return (
<div
id="message-box"
ref={messageBox}
className={`${className} p-3 rounded bg-secondary transition duration-1000 ease-in-out overflow-hidden`}
>
@ -272,7 +277,7 @@ export const GroupView = ({
className = "text-primary bg-primary ",
}: any) => {
return (
<div className={`rounded mt-4 border-secondary ${className}`}>
<div id="group-view" className={`rounded mt-4 border-secondary ${className}`}>
<div className="mt-4 p-2 rounded border relative">
<div className={`absolute -top-3 inline-block ${className}`}>
{title}
@ -297,6 +302,7 @@ export const ExpandView = ({
const minImageWidth = 400;
return (
<div
id="expand-view"
style={{
minHeight: "100px",
}}
@ -347,6 +353,7 @@ export const LoadingOverlay = ({ children, loading }: IProps) => {
{loading && (
<>
<div
id="loading-overlay"
className="absolute inset-0 bg-secondary flex pointer-events-none"
style={{ opacity: 0.5 }}
>
@ -376,7 +383,11 @@ export const MarkdownView = ({
showCode?: boolean;
}) => {
function processString(inputString: string): string {
inputString = inputString.replace(/\n/g, " \n");
// TODO: Had to add this temp measure while debugging. Why is it null?
if (!inputString) {
console.log("inputString is null!")
}
inputString = inputString && inputString.replace(/\n/g, " \n");
const markdownPattern = /```markdown\s+([\s\S]*?)\s+```/g;
return inputString?.replace(markdownPattern, (match, content) => content);
}
@ -449,6 +460,7 @@ export const MarkdownView = ({
return (
<div
id="markdown-view"
className={` w-full chatbox prose dark:prose-invert text-primary rounded ${className}`}
>
<ReactMarkdown
@ -499,7 +511,7 @@ export const CodeBlock = ({
const [showCopied, setShowCopied] = React.useState(false);
return (
<div className="relative">
<div id="code-block" className="relative">
<div className=" rounded absolute right-5 top-4 z-10 ">
<div className="relative border border-transparent w-full h-full">
<div
@ -566,7 +578,7 @@ export const ControlRowView = ({
truncateLength?: number;
}) => {
return (
<div className={`${className}`}>
<div id="control-row-view" className={`${className}`}>
<div>
<span className="text-primary inline-block">{title} </span>
<span className="text-xs ml-1 text-accent -mt-2 inline-block">
@ -590,7 +602,7 @@ export const BounceLoader = ({
title?: string;
}) => {
return (
<div className="inline-block">
<div id="bounce-loader" className="inline-block">
<div className="inline-flex gap-2">
<span className=" rounded-full bg-accent h-2 w-2 inline-block"></span>
<span className="animate-bounce rounded-full bg-accent h-3 w-3 inline-block"></span>
@ -611,10 +623,10 @@ export const ImageLoader = ({
const [isLoading, setIsLoading] = useState(true);
return (
<div className="w-full rounded relative">
<div id="image-loader" className="w-full rounded relative">
{isLoading && (
<div className="absolute h-24 inset-0 flex items-center justify-center">
<BounceLoader title=" loading .." />{" "}
<BounceLoader title=" loading .." />
</div>
)}
<img
@ -685,7 +697,7 @@ export const CsvLoader = ({
const scrollX = columns.length * 150;
return (
<div className={`CsvLoader ${className}`}>
<div id="csv-loader" className={`CsvLoader ${className}`}>
<Table
dataSource={data}
columns={columns}
@ -720,7 +732,7 @@ export const CodeLoader = ({
}, [url]);
return (
<div className={`w-full rounded relative ${className}`}>
<div id="code-loader" className={`w-full rounded relative ${className}`}>
{isLoading && (
<div className="absolute h-24 inset-0 flex items-center justify-center">
<BounceLoader />
@ -743,7 +755,7 @@ export const PdfViewer = ({ url }: { url: string }) => {
// Render the PDF viewer
return (
<div className="h-full">
<div id="pdf-viewer" className="h-full">
{loading && <p>Loading PDF...</p>}
{!loading && (
<object
@ -779,7 +791,7 @@ export const MonacoEditor = ({
setIsEditorReady(true);
};
return (
<div className="h-full rounded">
<div id="monaco-editor" className="h-full rounded">
<Editor
height="100%"
className="h-full rounded"
@ -820,6 +832,7 @@ export const CardHoverBar = ({
return (
<div
key={"cardhoverrow" + i}
id={`card-hover-bar-item-${i}`}
role="button"
className="text-accent text-xs inline-block hover:bg-primary p-2 rounded"
onClick={item.onClick}
@ -832,6 +845,7 @@ export const CardHoverBar = ({
});
return (
<div
id="card-hover-bar"
onMouseEnter={(e) => {
e.stopPropagation();
}}

View File

@ -193,8 +193,8 @@ export const AgentConfigView = ({
options={
[
{ label: "NEVER", value: "NEVER" },
// { label: "TERMINATE", value: "TERMINATE" },
// { label: "ALWAYS", value: "ALWAYS" },
{ label: "TERMINATE", value: "TERMINATE" },
{ label: "ALWAYS", value: "ALWAYS" },
] as any
}
/>

View File

@ -62,6 +62,10 @@ const ChatBox = ({
const [workflow, setWorkflow] = React.useState<IWorkflow | null>(null);
const [socketMessages, setSocketMessages] = React.useState<any[]>([]);
const [awaitingUserInput, setAwaitingUserInput] = React.useState(false); // New state for tracking user input
const setAreSessionButtonsDisabled = useConfigStore(
(state) => state.setAreSessionButtonsDisabled
);
const MAX_RETRIES = 10;
const RETRY_INTERVAL = 2000;
@ -102,7 +106,7 @@ const ChatBox = ({
try {
meta = JSON.parse(message.meta);
} catch (e) {
meta = message.meta;
meta = message?.meta;
}
const msg: IChatMessage = {
text: message.content,
@ -122,7 +126,6 @@ const ChatBox = ({
const initMsgs: IChatMessage[] = parseMessages(initMessages);
setMessages(initMsgs);
wsMessages.current = initMsgs;
socketMsgs = [];
}, [initMessages]);
const promptButtons = examplePrompts.map((prompt, i) => {
@ -141,7 +144,7 @@ const ChatBox = ({
);
});
const messageListView = messages?.map((message: IChatMessage, i: number) => {
const messageListView = messages && messages?.map((message: IChatMessage, i: number) => {
const isUser = message.sender === "user";
const css = isUser ? "bg-accent text-white " : "bg-light";
// console.log("message", message);
@ -209,7 +212,7 @@ const ChatBox = ({
return (
<div
className={`align-right ${isUser ? "text-righpt" : ""} mb-2 border-b`}
id={"message" + i} className={`align-right ${isUser ? "text-righpt" : ""} mb-2 border-b`}
key={"message" + i}
>
{" "}
@ -294,10 +297,10 @@ const ChatBox = ({
}, 500);
}, [messages]);
const textAreaDefaultHeight = "50px";
const textAreaDefaultHeight = "64px";
// clear text box if loading has just changed to false and there is no error
React.useEffect(() => {
if (loading === false && textAreaInputRef.current) {
if ((awaitingUserInput || loading === false) && textAreaInputRef.current) {
if (textAreaInputRef.current) {
if (error === null || (error && error.status === false)) {
textAreaInputRef.current.value = "";
@ -372,6 +375,19 @@ const ChatBox = ({
scrollChatBox(messageBoxInputRef);
}, 200);
// console.log("received message", data, socketMsgs.length);
} else if (data && data.type === "user_input_request") {
setAwaitingUserInput(true); // Set awaiting input state
textAreaInputRef.current.value = ""
textAreaInputRef.current.placeholder = data.data.message.content
const newsocketMessages = Object.assign([], socketMessages);
newsocketMessages.push(data.data);
setSocketMessages(newsocketMessages);
socketMsgs.push(data.data);
setTimeout(() => {
scrollChatBox(socketDivRef);
scrollChatBox(messageBoxInputRef);
}, 200);
ToastMessage.info(data.data.message)
} else if (data && data.type === "agent_status") {
// indicates a status message update
const agentStatusSpan = document.getElementById("agentstatusspan");
@ -380,6 +396,8 @@ const ChatBox = ({
}
} else if (data && data.type === "agent_response") {
// indicates a final agent response
setAwaitingUserInput(false); // Set awaiting input state
setAreSessionButtonsDisabled(false);
processAgentResponse(data.data);
}
};
@ -408,11 +426,13 @@ const ChatBox = ({
wsMessages.current.push(msg);
setMessages(wsMessages.current);
setLoading(false);
setAwaitingUserInput(false);
} else {
console.log("error", data);
// setError(data);
ToastMessage.error(data.message);
setLoading(false);
setAwaitingUserInput(false);
}
};
@ -442,6 +462,7 @@ const ChatBox = ({
const runWorkflow = (query: string) => {
setError(null);
setAreSessionButtonsDisabled(true);
socketMsgs = [];
let messageHolder = Object.assign([], messages);
@ -519,6 +540,43 @@ const ChatBox = ({
}
};
const sendUserResponse = (userResponse: string) => {
setAwaitingUserInput(false);
setError(null);
setLoading(true);
textAreaInputRef.current.placeholder = "Write message here..."
const userMessage: IChatMessage = {
text: userResponse,
sender: "system",
};
const messagePayload: IMessage = {
role: "user",
content: userResponse,
user_id: user?.email || "",
session_id: session?.id,
workflow_id: session?.workflow_id,
connection_id: connectionId,
};
// check if socket connected,
if (wsClient.current && wsClient.current.readyState === 1) {
wsClient.current.send(
JSON.stringify({
connection_id: connectionId,
data: messagePayload,
type: "user_message",
session_id: session?.id,
workflow_id: session?.workflow_id,
})
);
} else {
console.err("websocket client error")
}
};
const handleTextChange = (
event: React.ChangeEvent<HTMLTextAreaElement>
): void => {
@ -529,11 +587,16 @@ const ChatBox = ({
event: React.KeyboardEvent<HTMLTextAreaElement>
): void => {
if (event.key === "Enter" && !event.shiftKey) {
if (textAreaInputRef.current && !loading) {
if (textAreaInputRef.current &&(awaitingUserInput || !loading)) {
event.preventDefault();
if (awaitingUserInput) {
sendUserResponse(textAreaInputRef.current.value); // New function call for sending user input
textAreaInputRef.current.value = "";
} else {
runWorkflow(textAreaInputRef.current.value);
}
}
}
};
const getConnectionColor = (status: string) => {
@ -554,7 +617,7 @@ const ChatBox = ({
const WorkflowView = ({ workflow }: { workflow: IWorkflow }) => {
return (
<div className="text-xs cursor-pointer inline-block">
<div id="workflow-view" className="text-xs cursor-pointer inline-block">
{" "}
{workflow.name}
</div>
@ -563,11 +626,13 @@ const ChatBox = ({
return (
<div
id="chatbox-main"
style={{ height: "calc(100vh - " + heightOffset + "px)" }}
className="text-primary relative rounded "
ref={mainDivRef}
>
<div
id="workflow-name"
style={{ zIndex: 100 }}
className=" absolute right-3 bg-primary rounded text-secondary -top-6 p-2"
>
@ -576,17 +641,18 @@ const ChatBox = ({
</div>
<div
id="message-box"
ref={messageBoxInputRef}
className="flex h-full flex-col rounded scroll pr-2 overflow-auto "
style={{ minHeight: "30px", height: "calc(100vh - 310px)" }}
>
<div className="scroll-gradient h-10">
<div id="scroll-gradient" className="scroll-gradient h-10">
{" "}
<span className=" inline-block h-6"></span>{" "}
</div>
<div className="flex-1 boder mt-4"></div>
{!messages && messages !== null && (
<div className="w-full text-center boder mt-4">
<div id="loading-messages" className="w-full text-center boder mt-4">
<div>
{" "}
<BounceLoader />
@ -596,15 +662,15 @@ const ChatBox = ({
)}
{messages && messages?.length === 0 && (
<div className="ml-2 text-sm text-secondary ">
<div id="no-messages" className="ml-2 text-sm text-secondary ">
<InformationCircleIcon className="inline-block h-6 mr-2" />
No messages in the current session. Start a conversation to begin.
</div>
)}
<div className="ml-2"> {messageListView}</div>
{loading && (
<div className={` inline-flex gap-2 duration-300 `}>
<div id="message-list" className="ml-2"> {messageListView}</div>
{(loading || awaitingUserInput) && (
<div id="loading-bar" className={` inline-flex gap-2 duration-300 `}>
<div className=""></div>
<div className="font-semibold text-secondary text-sm w-16">
AGENTS
@ -632,6 +698,7 @@ const ChatBox = ({
{socketMsgs.length > 0 && (
<div
id="agent-messages"
ref={socketDivRef}
style={{
minHeight: "300px",
@ -661,10 +728,11 @@ const ChatBox = ({
)}
</div>
{editable && (
<div className="mt-2 p-2 absolute bg-primary bottom-0 w-full">
<div id="input-area" className="mt-2 p-2 absolute bg-primary bottom-0 w-full">
<div
id="input-form"
className={`rounded p-2 shadow-lg flex mb-1 gap-2 ${
loading ? " opacity-50 pointer-events-none" : ""
loading && !awaitingUserInput ? " opacity-50 pointer-events-none" : ""
}`}
>
{/* <input className="flex-1 p-2 ring-2" /> */}
@ -683,7 +751,7 @@ const ChatBox = ({
onChange={handleTextChange}
placeholder="Write message here..."
ref={textAreaInputRef}
className="flex items-center w-full resize-none text-gray-600 bg-white p-2 ring-2 rounded-sm pl-5 pr-16"
className="flex items-center w-full resize-none text-gray-600 bg-white p-2 ring-2 rounded-sm pl-5 pr-16 h-64"
style={{
maxHeight: "120px",
overflowY: "auto",
@ -691,23 +759,28 @@ const ChatBox = ({
}}
/>
<div
id="send-button"
role={"button"}
style={{ width: "45px", height: "35px" }}
title="Send message"
onClick={() => {
if (textAreaInputRef.current && !loading) {
if (textAreaInputRef.current && (awaitingUserInput || !loading)) {
if (awaitingUserInput) {
sendUserResponse(textAreaInputRef.current.value); // Use the new function for user input
} else {
runWorkflow(textAreaInputRef.current.value);
}
}
}}
className="absolute right-3 bottom-2 bg-accent hover:brightness-75 transition duration-300 rounded cursor-pointer flex justify-center items-center"
>
{" "}
{!loading && (
{(awaitingUserInput || !loading) && (
<div className="inline-block ">
<PaperAirplaneIcon className="h-6 w-6 text-white " />{" "}
</div>
)}
{loading && (
{loading && !awaitingUserInput && (
<div className="inline-block ">
<Cog6ToothIcon className="text-white animate-spin rounded-full h-6 w-6" />
</div>
@ -728,15 +801,16 @@ const ChatBox = ({
</div>
<div
id="prompt-buttons"
className={`mt-2 inline-flex gap-2 flex-wrap ${
loading ? "brightness-75 pointer-events-none" : ""
(loading && !awaitingUserInput) ? "brightness-75 pointer-events-none" : ""
}`}
>
{promptButtons}
</div>
</div>
{error && !error.status && (
<div className="p-2 rounded mt-4 text-orange-500 text-sm">
<div id="error-message" className="p-2 rounded mt-4 text-orange-500 text-sm">
{" "}
<ExclamationTriangleIcon className="h-5 text-orange-500 inline-block mr-2" />{" "}
{error.message}

View File

@ -51,6 +51,10 @@ const SessionsView = ({}: any) => {
const session = useConfigStore((state) => state.session);
const setSession = useConfigStore((state) => state.setSession);
const isSessionButtonsDisabled = useConfigStore(
(state) => state.areSessionButtonsDisabled
);
const deleteSession = (session: IChatSession) => {
setError(null);
setLoading(true);
@ -284,7 +288,9 @@ const SessionsView = ({}: any) => {
return (
<div
key={"sessionsrow" + index}
className="group relative mb-2 pb-1 border-b border-dashed "
className={`group relative mb-2 pb-1 border-b border-dashed ${
isSessionButtonsDisabled ? "opacity-50 pointer-events-none" : ""
}`}
>
{items.length > 0 && (
<div className=" absolute right-2 top-2 group-hover:opacity-100 opacity-0 ">
@ -295,8 +301,10 @@ const SessionsView = ({}: any) => {
className={`rounded p-2 cursor-pointer ${rowClass}`}
role="button"
onClick={() => {
setSession(data);
// setWorkflowConfig(data.flow_config);
if (!isSessionButtonsDisabled) {
setSession(data);
}
}}
>
<div className="text-xs mt-1">
@ -424,7 +432,7 @@ const SessionsView = ({}: any) => {
<div className="flex gap-x-2">
<div className="flex-1"></div>
<LaunchButton
className="text-sm p-2 px-3"
className={`text-sm p-2 px-3 ${isSessionButtonsDisabled ? 'opacity-50 cursor-not-allowed' : ''}`}
onClick={() => {
setSelectedSession(sampleSession);
setNewSessionModalVisible(true);

View File

@ -14,6 +14,8 @@ interface ConfigState {
setVersion: (version: string) => void;
connectionId: string;
setConnectionId: (connectionId: string) => void;
areSessionButtonsDisabled: boolean;
setAreSessionButtonsDisabled: (disabled: boolean) => void;
}
export const useConfigStore = create<ConfigState>()((set) => ({
@ -27,4 +29,6 @@ export const useConfigStore = create<ConfigState>()((set) => ({
setVersion: (version) => set({ version }),
connectionId: uuidv4(),
setConnectionId: (connectionId) => set({ connectionId }),
areSessionButtonsDisabled: false,
setAreSessionButtonsDisabled: (disabled) => set({ areSessionButtonsDisabled: disabled }),
}));