Teachability for any agent (#1091)

* Partial implementation

* Partial implementation

* Fixes

* update tests

* cleanup

* update tests

* comments

* logging

* wording

* underscore

* Extend notebook for teachable GPTAssistantAgent

* Notebook for teachable GPTAssistantAgents

* Update notebook

* Update notebook

* Update notebook

* Update notebook

* revert file

* Update blog post and other documentation.

* pre-commit

* Address reviewer feedback.

* Add new nb link to examples page.

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Ricky Loynd 2024-01-07 11:59:42 -08:00 committed by GitHub
parent 172df5545d
commit 3680197f70
13 changed files with 1715 additions and 468 deletions


@ -155,11 +155,11 @@ jobs:
run: |
python -m pip install --upgrade pip wheel
pip install pytest
- name: Install packages and dependencies for TeachableAgent
- name: Install packages and dependencies for Teachability
run: |
pip install -e .[teachable]
pip uninstall -y openai
- name: Test TeachableAgent
- name: Test Teachability
if: matrix.python-version != '3.9' # diversify the python versions
run: |
pytest test/agentchat/contrib/test_teachable_agent.py

.gitignore

@ -165,7 +165,7 @@ key_aoai.txt
base_aoai.txt
wolfram.txt
# DB on disk for TeachableAgent
# DB on disk for Teachability
tmp/
test/my_tmp/*


@ -0,0 +1,15 @@
from autogen.agentchat.assistant_agent import ConversableAgent
class AgentCapability:
"""Base class for composable capabilities that can be added to an agent."""
def __init__(self):
pass
def add_to_agent(self, agent: ConversableAgent):
"""
Adds a particular capability to the given agent. Must be implemented by the capability subclass.
An implementation will typically call agent.register_hook() one or more times. See teachability.py as an example.
"""
raise NotImplementedError
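To illustrate the pattern this base class defines, here is a minimal sketch (not part of this commit) of a hypothetical capability. The class name `MessageTagger` and the appended note are illustrative assumptions; the only APIs relied on are `AgentCapability.add_to_agent` and the `register_hook` / `process_last_message` methods introduced by this change.

```python
from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability


class MessageTagger(AgentCapability):
    """Hypothetical capability that appends a fixed note to each incoming message."""

    def __init__(self, note: str = "(Remember to answer concisely.)"):
        self.note = note

    def add_to_agent(self, agent: ConversableAgent):
        # Hook into the agent's last-message processing, as Teachability does below.
        agent.register_hook(hookable_method=agent.process_last_message, hook=self._append_note)

    def _append_note(self, text: str) -> str:
        # Receives the text of the last message and returns the (possibly modified) text to use instead.
        return text + "\n" + self.note
```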


@ -1,13 +1,11 @@
import os
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import ConversableAgent
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability
from autogen.agentchat.contrib.text_analyzer_agent import TextAnalyzerAgent
from typing import Callable, Dict, Optional, Union, List, Tuple, Any
from typing import Dict, Optional, Union, List, Tuple, Any
import chromadb
from chromadb.config import Settings
import pickle
from tqdm import tqdm
try:
from termcolor import colored
@ -17,148 +15,113 @@ except ImportError:
return x
class TeachableAgent(ConversableAgent):
"""(Experimental) Teachable Agent, a subclass of ConversableAgent using a vector database to remember user teachings.
In this class, the term 'user' refers to any caller (human or not) sending messages to this agent.
Not yet tested in the group-chat setting."""
class Teachability(AgentCapability):
"""
Teachability uses a vector database to give an agent the ability to remember user teachings,
where the user is any caller (human or not) sending messages to the teachable agent.
Teachability is designed to be composable with other agent capabilities.
To make any conversable agent teachable, instantiate both the agent and the Teachability class,
then pass the agent to teachability.add_to_agent(agent).
Note that teachable agents in a group chat must be given unique path_to_db_dir values.
"""
def __init__(
self,
name="teachableagent",
system_message: Optional[
str
] = "You are a helpful AI assistant that remembers user teachings from prior chats.",
human_input_mode: Optional[str] = "NEVER",
verbosity: Optional[int] = 0,
reset_db: Optional[bool] = False,
path_to_db_dir: Optional[str] = "./tmp/teachable_agent_db",
recall_threshold: Optional[float] = 1.5,
max_num_retrievals: Optional[int] = 10,
llm_config: Optional[Union[Dict, bool]] = None,
analyzer_llm_config: Optional[Union[Dict, bool]] = None,
teach_config: Optional[Dict] = None,
**kwargs,
):
"""
Args:
name (str): name of the agent.
system_message (str): system message for the ChatCompletion inference.
human_input_mode (str): This agent should NEVER prompt the human for input.
llm_config (dict or False): llm inference configuration.
Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
for available options.
To disable llm-based auto reply, set to False.
analyzer_llm_config (dict or False): llm inference configuration passed to TextAnalyzerAgent.
Given the default setting of None, TeachableAgent passes its own llm_config to TextAnalyzerAgent.
teach_config (dict or None): Additional parameters used by TeachableAgent.
To use default config, set to None. Otherwise, set to a dictionary with any of the following keys:
- verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
- reset_db (Optional, bool): True to clear the DB before starting. Default False.
- path_to_db_dir (Optional, str): path to the directory where the DB is stored. Default "./tmp/teachable_agent_db"
- prepopulate (Optional, int): True (default) to prepopulate the DB with a set of input-output pairs.
- recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
- max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
**kwargs (dict): other kwargs in [ConversableAgent](../conversable_agent#__init__).
verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
reset_db (Optional, bool): True to clear the DB before starting. Default False.
path_to_db_dir (Optional, str): path to the directory where this particular agent's DB is stored. Default "./tmp/teachable_agent_db"
recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
llm_config (dict or False): llm inference configuration passed to TextAnalyzerAgent.
If None, TextAnalyzerAgent uses llm_config from the teachable agent.
"""
super().__init__(
name=name,
system_message=system_message,
human_input_mode=human_input_mode,
llm_config=llm_config,
**kwargs,
)
# Register a custom reply function.
self.register_reply(Agent, TeachableAgent._generate_teachable_assistant_reply, position=2)
self.verbosity = verbosity
self.path_to_db_dir = path_to_db_dir
self.recall_threshold = recall_threshold
self.max_num_retrievals = max_num_retrievals
self.llm_config = llm_config
# Assemble the parameter settings.
self._teach_config = {} if teach_config is None else teach_config
self.verbosity = self._teach_config.get("verbosity", 0)
self.reset_db = self._teach_config.get("reset_db", False)
self.path_to_db_dir = self._teach_config.get("path_to_db_dir", "./tmp/teachable_agent_db")
self.prepopulate = self._teach_config.get("prepopulate", True)
self.recall_threshold = self._teach_config.get("recall_threshold", 1.5)
self.max_num_retrievals = self._teach_config.get("max_num_retrievals", 10)
# Create the analyzer.
if analyzer_llm_config is None:
analyzer_llm_config = llm_config
self.analyzer = TextAnalyzerAgent(llm_config=analyzer_llm_config)
self.analyzer = None
self.teachable_agent = None
# Create the memo store.
self.memo_store = MemoStore(self.verbosity, self.reset_db, self.path_to_db_dir)
self.user_comments = [] # Stores user comments until the end of each chat.
self.memo_store = MemoStore(self.verbosity, reset_db, self.path_to_db_dir)
def close_db(self):
"""Cleanly closes the memo store."""
self.memo_store.close()
def add_to_agent(self, agent: ConversableAgent):
"""Adds teachability to the given agent."""
self.teachable_agent = agent
# Register a hook for processing the last message.
agent.register_hook(hookable_method=agent.process_last_message, hook=self.process_last_message)
# Was an llm_config passed to the constructor?
if self.llm_config is None:
# No. Use the agent's llm_config.
self.llm_config = agent.llm_config
assert self.llm_config, "Teachability requires a valid llm_config."
# Create the analyzer agent.
self.analyzer = TextAnalyzerAgent(llm_config=self.llm_config)
# Append extra info to the system message.
agent.update_system_message(
agent.system_message
+ "\nYou've been given the special ability to remember user teachings from prior conversations."
)
def prepopulate_db(self):
"""Adds a few arbitrary memos to the DB."""
self.memo_store.prepopulate()
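As a hedged sketch of the composition described in the class docstring above (not part of this commit), the snippet below gives two ordinary agents their own Teachability instances with distinct `path_to_db_dir` values, as required for teachable agents in a group chat. The agent names, model filter, and paths are placeholders.

```python
import autogen
from autogen import ConversableAgent
from autogen.agentchat.contrib.capabilities.teachability import Teachability

# Placeholder LLM config; adjust the filter to the models available to you.
config_list = autogen.config_list_from_json("OAI_CONFIG_LIST", filter_dict={"model": ["gpt-4"]})
llm_config = {"config_list": config_list, "cache_seed": None}

coder = ConversableAgent(name="coder", llm_config=llm_config)
planner = ConversableAgent(name="planner", llm_config=llm_config)

# One Teachability instance, each with its own DB directory, per teachable agent.
Teachability(path_to_db_dir="./tmp/coder_db").add_to_agent(coder)
Teachability(path_to_db_dir="./tmp/planner_db").add_to_agent(planner)
```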
def _generate_teachable_assistant_reply(
self,
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[Any] = None, # Persistent state.
) -> Tuple[bool, Union[str, Dict, None]]:
def process_last_message(self, text):
"""
Generates a reply to the last user message, after querying the memo store for relevant information.
Appends any relevant memos to the message text, and stores any apparent teachings in new memos.
Uses TextAnalyzerAgent to make decisions about memo storage and retrieval.
"""
if self.llm_config is False:
raise ValueError("TeachableAgent requires self.llm_config to be set in its base class.")
if messages is None:
messages = self._oai_messages[sender] # In case of a direct call.
# Get the last user turn.
last_message = messages[-1]
user_text = last_message["content"]
if (not isinstance(user_text, str)) or ("context" in last_message):
raise ValueError(
"TeachableAgent currently assumes that the message content is a simple string. This error serves to flag a test case for relaxing this assumption."
)
# Keep track of this user turn as a potential source of memos later.
self.user_comments.append(user_text)
# Consider whether to retrieve something from the DB.
# Try to retrieve relevant memos from the DB.
expanded_text = text
if self.memo_store.last_memo_id > 0:
new_user_text = self.consider_memo_retrieval(user_text)
if new_user_text != user_text:
# Make a copy of the message list, and replace the last user message with the new one.
messages = messages.copy()
messages[-1]["content"] = new_user_text
expanded_text = self._consider_memo_retrieval(text)
# Generate a response by reusing existing generate_oai_reply
return self.generate_oai_reply(messages, sender, config)
# Try to store any user teachings in new memos to be used in the future.
self._consider_memo_storage(text)
def learn_from_user_feedback(self):
"""Reviews the user comments from the last chat, and decides what teachings to store as memos."""
print(colored("\nREVIEWING CHAT FOR USER TEACHINGS TO REMEMBER", "light_yellow"))
# Look at each user turn.
if len(self.user_comments) > 0:
for comment in tqdm(self.user_comments):
# Consider whether to store something from this user turn in the DB.
self.consider_memo_storage(comment)
self.user_comments = []
# Return the (possibly) expanded message text.
return expanded_text
def consider_memo_storage(self, comment):
def _consider_memo_storage(self, comment):
"""Decides whether to store something from one user comment in the DB."""
memo_added = False
# Check for a problem-solution pair.
response = self.analyze(
response = self._analyze(
comment,
"Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
)
if "yes" in response.lower():
# Can we extract advice?
advice = self.analyze(
advice = self._analyze(
comment,
"Briefly copy any advice from the TEXT that may be useful for a similar but different task in the future. But if no advice is present, just respond with 'none'.",
)
if "none" not in advice.lower():
# Yes. Extract the task.
task = self.analyze(
task = self._analyze(
comment,
"Briefly copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice.",
)
# Generalize the task.
general_task = self.analyze(
general_task = self._analyze(
task,
"Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
)
@ -166,37 +129,44 @@ class TeachableAgent(ConversableAgent):
if self.verbosity >= 1:
print(colored("\nREMEMBER THIS TASK-ADVICE PAIR", "light_yellow"))
self.memo_store.add_input_output_pair(general_task, advice)
memo_added = True
# Check for information to be learned.
response = self.analyze(
response = self._analyze(
comment,
"Does the TEXT contain information that could be committed to memory? Answer with just one word, yes or no.",
)
if "yes" in response.lower():
# Yes. What question would this information answer?
question = self.analyze(
question = self._analyze(
comment,
"Imagine that the user forgot this information in the TEXT. How would they ask you for this information? Include no other text in your response.",
)
# Extract the information.
answer = self.analyze(
answer = self._analyze(
comment, "Copy the information from the TEXT that should be committed to memory. Add no explanation."
)
# Add the question-answer pair to the vector DB.
if self.verbosity >= 1:
print(colored("\nREMEMBER THIS QUESTION-ANSWER PAIR", "light_yellow"))
self.memo_store.add_input_output_pair(question, answer)
memo_added = True
def consider_memo_retrieval(self, comment):
# Were any memos added?
if memo_added:
# Yes. Save them to disk.
self.memo_store._save_memos()
def _consider_memo_retrieval(self, comment):
"""Decides whether to retrieve memos from the DB, and add them to the chat context."""
# First, use the user comment directly as the lookup key.
# First, use the comment directly as the lookup key.
if self.verbosity >= 1:
print(colored("\nLOOK FOR RELEVANT MEMOS, AS QUESTION-ANSWER PAIRS", "light_yellow"))
memo_list = self.retrieve_relevant_memos(comment)
memo_list = self._retrieve_relevant_memos(comment)
# Next, if the comment involves a task, then extract and generalize the task before using it as the lookup key.
response = self.analyze(
response = self._analyze(
comment,
"Does any part of the TEXT ask the agent to perform a task or solve a problem? Answer with just one word, yes or no.",
)
@ -204,24 +174,24 @@ class TeachableAgent(ConversableAgent):
if self.verbosity >= 1:
print(colored("\nLOOK FOR RELEVANT MEMOS, AS TASK-ADVICE PAIRS", "light_yellow"))
# Extract the task.
task = self.analyze(
task = self._analyze(
comment, "Copy just the task from the TEXT, then stop. Don't solve it, and don't include any advice."
)
# Generalize the task.
general_task = self.analyze(
general_task = self._analyze(
task,
"Summarize very briefly, in general terms, the type of task described in the TEXT. Leave out details that might not appear in a similar problem.",
)
# Append any relevant memos.
memo_list.extend(self.retrieve_relevant_memos(general_task))
memo_list.extend(self._retrieve_relevant_memos(general_task))
# De-duplicate the memo list.
memo_list = list(set(memo_list))
# Append the memos to the last user message.
return comment + self.concatenate_memo_texts(memo_list)
# Append the memos to the text of the last message.
return comment + self._concatenate_memo_texts(memo_list)
def retrieve_relevant_memos(self, input_text):
def _retrieve_relevant_memos(self, input_text):
"""Returns semantically related memos from the DB."""
memo_list = self.memo_store.get_related_memos(
input_text, n_results=self.max_num_retrievals, threshold=self.recall_threshold
@ -239,7 +209,7 @@ class TeachableAgent(ConversableAgent):
memo_list = [memo[1] for memo in memo_list]
return memo_list
def concatenate_memo_texts(self, memo_list):
def _concatenate_memo_texts(self, memo_list):
"""Concatenates the memo texts into a single string for inclusion in the chat context."""
memo_texts = ""
if len(memo_list) > 0:
@ -247,30 +217,25 @@ class TeachableAgent(ConversableAgent):
for memo in memo_list:
info = info + "- " + memo + "\n"
if self.verbosity >= 1:
print(colored("\nMEMOS APPENDED TO LAST USER MESSAGE...\n" + info + "\n", "light_yellow"))
print(colored("\nMEMOS APPENDED TO LAST MESSAGE...\n" + info + "\n", "light_yellow"))
memo_texts = memo_texts + "\n" + info
return memo_texts
def analyze(self, text_to_analyze, analysis_instructions):
def _analyze(self, text_to_analyze, analysis_instructions):
"""Asks TextAnalyzerAgent to analyze the given text according to specific instructions."""
if self.verbosity >= 2:
# Use the messaging mechanism so that the analyzer's messages are included in the printed chat.
self.analyzer.reset() # Clear the analyzer's list of messages.
self.send(
recipient=self.analyzer, message=text_to_analyze, request_reply=False
) # Put the message in the analyzer's list.
self.send(recipient=self.analyzer, message=analysis_instructions, request_reply=True) # Request the reply.
return self.last_message(self.analyzer)["content"]
else:
# TODO: This is not an encouraged usage pattern. It breaks the conversation-centric design.
# consider using the arg "silent"
# Use the analyzer's method directly, to leave analyzer message out of the printed chat.
return self.analyzer.analyze_text(text_to_analyze, analysis_instructions)
self.analyzer.reset() # Clear the analyzer's list of messages.
self.teachable_agent.send(
recipient=self.analyzer, message=text_to_analyze, request_reply=False, silent=(self.verbosity < 2)
) # Put the message in the analyzer's list.
self.teachable_agent.send(
recipient=self.analyzer, message=analysis_instructions, request_reply=True, silent=(self.verbosity < 2)
) # Request the reply.
return self.teachable_agent.last_message(self.analyzer)["content"]
class MemoStore:
"""(Experimental)
Provides memory storage and retrieval for a TeachableAgent, using a vector database.
"""
Provides memory storage and retrieval for a teachable agent, using a vector database.
Each DB entry (called a memo) is a pair of strings: an input text and an output text.
The input text might be a question, or a task to perform.
The output text might be an answer to the question, or advice on how to perform the task.
@ -284,7 +249,6 @@ class MemoStore:
- path_to_db_dir (Optional, str): path to the directory where the DB is stored.
"""
self.verbosity = verbosity
self.reset = reset
self.path_to_db_dir = path_to_db_dir
# Load or create the vector DB on disk.
@ -293,8 +257,6 @@ class MemoStore:
)
self.db_client = chromadb.Client(settings)
self.vec_db = self.db_client.create_collection("memos", get_or_create=True) # The collection is the DB.
if reset:
self.reset_db()
# Load or create the associated memo dict on disk.
self.path_to_dict = os.path.join(path_to_db_dir, "uid_text_dict.pkl")
@ -309,6 +271,10 @@ class MemoStore:
if self.verbosity >= 3:
self.list_memos()
# Clear the DB if requested.
if reset:
self.reset_db()
def list_memos(self):
"""Prints the contents of MemoStore."""
print(colored("LIST OF MEMOS", "light_green"))
@ -321,10 +287,8 @@ class MemoStore:
)
)
def close(self):
def _save_memos(self):
"""Saves self.uid_text_dict to disk."""
print(colored("\nSAVING MEMORY TO DISK", "light_green"))
print(colored(" Location = {}".format(self.path_to_dict), "light_green"))
with open(self.path_to_dict, "wb") as file:
pickle.dump(self.uid_text_dict, file)
@ -334,6 +298,7 @@ class MemoStore:
self.db_client.delete_collection("memos")
self.vec_db = self.db_client.create_collection("memos")
self.uid_text_dict = {}
self._save_memos()
def add_input_output_pair(self, input_text, output_text):
"""Adds an input-output pair to the vector DB."""
@ -343,10 +308,10 @@ class MemoStore:
if self.verbosity >= 1:
print(
colored(
"\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {}\n INPUT\n {}\n OUTPUT\n {}".format(
"\nINPUT-OUTPUT PAIR ADDED TO VECTOR DATABASE:\n ID\n {}\n INPUT\n {}\n OUTPUT\n {}\n".format(
self.last_memo_id, input_text, output_text
),
"light_green",
"light_yellow",
)
)
if self.verbosity >= 3:
@ -364,7 +329,7 @@ class MemoStore:
"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format(
input_text, output_text, distance
),
"light_green",
"light_yellow",
)
)
return input_text, output_text, distance
@ -387,7 +352,7 @@ class MemoStore:
"\nINPUT-OUTPUT PAIR RETRIEVED FROM VECTOR DATABASE:\n INPUT1\n {}\n OUTPUT\n {}\n DISTANCE\n {}".format(
input_text, output_text, distance
),
"light_green",
"light_yellow",
)
)
memos.append((input_text, output_text, distance))
@ -422,3 +387,4 @@ class MemoStore:
)
for example in examples:
self.add_input_output_pair(example["text"], example["label"])
self._save_memos()
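For reference, a minimal sketch of exercising `MemoStore` directly (normally `Teachability` manages it): the path, verbosity, and texts are placeholders, and the constructor arguments are assumed to follow the signature shown above. `get_related_memos` returns (input, output, distance) tuples whose vector distance falls below the given threshold.

```python
from autogen.agentchat.contrib.capabilities.teachability import MemoStore

# Create (or reset) a small memo store on disk.
store = MemoStore(verbosity=1, reset=True, path_to_db_dir="./tmp/demo_memo_db")

# Store one input-output pair, then retrieve memos within the distance threshold.
store.add_input_output_pair(
    "When summarizing an abstract, what format should I use?",
    "Use three short bullet points: the title, the innovation, and the key empirical results.",
)
for input_text, output_text, distance in store.get_related_memos(
    "How should I summarize an abstract?", n_results=5, threshold=1.5
):
    print(f"{distance:.3f}  {output_text}")
```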


@ -29,14 +29,6 @@ class TextAnalyzerAgent(ConversableAgent):
Please refer to [OpenAIWrapper.create](/docs/reference/oai/client#create)
for available options.
To disable llm-based auto reply, set to False.
teach_config (dict or None): Additional parameters used by TeachableAgent.
To use default config, set to None. Otherwise, set to a dictionary with any of the following keys:
- verbosity (Optional, int): # 0 (default) for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
- reset_db (Optional, bool): True to clear the DB before starting. Default False.
- path_to_db_dir (Optional, str): path to the directory where the DB is stored. Default "./tmp/teachable_agent_db"
- prepopulate (Optional, int): True (default) to prepopulate the DB with a set of input-output pairs.
- recall_threshold (Optional, float): The maximum distance for retrieved memos, where 0.0 is exact match. Default 1.5. Larger values allow more (but less relevant) memos to be recalled.
- max_num_retrievals (Optional, int): The maximum number of memos to retrieve from the DB. Default 10.
**kwargs (dict): other kwargs in [ConversableAgent](../conversable_agent#__init__).
"""
super().__init__(
@ -56,7 +48,7 @@ class TextAnalyzerAgent(ConversableAgent):
) -> Tuple[bool, Union[str, Dict, None]]:
"""Analyzes the given text as instructed, and returns the analysis as a message.
Assumes exactly two messages containing the text to analyze and the analysis instructions.
See TeachableAgent.analyze for an example of how to use this method."""
See Teachability.analyze for an example of how to use this method."""
if self.llm_config is False:
raise ValueError("TextAnalyzerAgent requires self.llm_config to be set in its base class.")
if messages is None:


@ -152,6 +152,10 @@ class ConversableAgent(Agent):
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply)
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
# New hookable methods should be added to this list as required to support new agent capabilities.
self.hook_lists = {self.process_last_message: []} # This is currently the only hookable method.
def register_reply(
self,
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
@ -757,7 +761,7 @@ class ConversableAgent(Agent):
else:
messages_to_scan += 1
# iterate through the last n messages reversely
# iterate through the last n messages in reverse
# if code blocks are found, execute the code blocks and return the output
# if no code blocks are found, continue
for i in range(min(len(messages), messages_to_scan)):
@ -1173,6 +1177,10 @@ class ConversableAgent(Agent):
if messages is None:
messages = self._oai_messages[sender]
# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
@ -1225,6 +1233,10 @@ class ConversableAgent(Agent):
if messages is None:
messages = self._oai_messages[sender]
# Call the hookable method that gives registered hooks a chance to process the last message.
# Message modifications do not affect the incoming messages or self._oai_messages.
messages = self.process_last_message(messages)
for reply_func_tuple in self._reply_func_list:
reply_func = reply_func_tuple["reply_func"]
if exclude and reply_func in exclude:
@ -1757,3 +1769,56 @@ class ConversableAgent(Agent):
return func
return _decorator
def register_hook(self, hookable_method: Callable, hook: Callable):
"""
Registers a hook to be called by a hookable method, in order to add a capability to the agent.
Registered hooks are kept in lists (one per hookable method), and are called in their order of registration.
Args:
hookable_method: A hookable method implemented by ConversableAgent.
hook: A method implemented by a subclass of AgentCapability.
"""
assert hookable_method in self.hook_lists, f"{hookable_method} is not a hookable method."
hook_list = self.hook_lists[hookable_method]
assert hook not in hook_list, f"{hook} is already registered as a hook."
hook_list.append(hook)
def process_last_message(self, messages):
"""
Calls any registered capability hooks to use and potentially modify the text of the last message,
as long as the last message is not a function call or exit command.
"""
# If any required condition is not met, return the original message list.
hook_list = self.hook_lists[self.process_last_message]
if len(hook_list) == 0:
return messages # No hooks registered.
if messages is None:
return None # No message to process.
if len(messages) == 0:
return messages # No message to process.
last_message = messages[-1]
if "function_call" in last_message:
return messages # Last message is a function call.
if "context" in last_message:
return messages # Last message contains a context key.
if "content" not in last_message:
return messages # Last message has no content.
user_text = last_message["content"]
if not isinstance(user_text, str):
return messages # Last message content is not a string. TODO: Multimodal agents will use a dict here.
if user_text == "exit":
return messages # Last message is an exit command.
# Call each hook (in order of registration) to process the user's message.
processed_user_text = user_text
for hook in hook_list:
processed_user_text = hook(processed_user_text)
if processed_user_text == user_text:
return messages # No hooks actually modified the user's message.
# Replace the last user message with the expanded one.
messages = messages.copy()
messages[-1]["content"] = processed_user_text
return messages
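As a hedged illustration of the hook mechanism (not part of this commit), a plain function can also be registered directly, without the `AgentCapability` wrapper. The agent name, hook body, and the message used to exercise `process_last_message` are assumptions.

```python
from autogen import ConversableAgent

agent = ConversableAgent(name="assistant", llm_config=False)  # No LLM call needed for this sketch.

def add_reminder(text: str) -> str:
    # Receives the last message's text and returns the text to use in its place.
    return text + "\n(Please keep the reply under 100 words.)"

agent.register_hook(hookable_method=agent.process_last_message, hook=add_reminder)

# process_last_message runs automatically inside generate_reply; calling it directly
# here just shows how the last message's content is expanded by the registered hook.
print(agent.process_last_message([{"role": "user", "content": "Summarize AutoGen."}]))
```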


@ -13,15 +13,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Chatting with TeachableAgent\n",
"# Chatting with a teachable agent\n",
"\n",
"Conversational assistants based on LLMs can remember the current chat with the user, and can even demonstrate in-context learning of things that the user teaches the assistant during the chat. But these memories and learnings are lost once the chat is over, or when a single chat grows too long for the LLM to handle effectively. In subsequent chats, the user is forced to repeat any necessary instructions over and over.\n",
"\n",
"`TeachableAgent` addresses these limitations by persisting user teachings across chat boundaries in long-term memory (a vector database). Memory is saved to disk at the end of each chat, then loaded from disk at the start of the next. Instead of copying all of memory into the context window, which would eat up valuable space, individual memories (called memos) are retrieved into context as needed. This allows the user to teach frequently used facts and skills to the teachable agent just once, and have it remember them in later chats.\n",
"The optional agent capability called `Teachability` addresses these limitations by persisting user teachings across chat boundaries in long-term memory (a vector database). Memories (called memos) are created and saved to disk throughout a conversation, then loaded from disk later. Instead of copying all the memos into the context window, which would eat up valuable space, individual memos are retrieved into context only as needed. This allows the user to teach many facts, preferences and skills to the teachable agent just once, and have it remember them in later chats.\n",
"\n",
"In making decisions about memo storage and retrieval, `TeachableAgent` calls an instance of `TextAnalyzerAgent` to analyze pieces of text in several different ways. This adds extra LLM calls involving a relatively small number of tokens. These calls can add a few seconds to the time a user waits for a response.\n",
"In making decisions about memo storage and retrieval, `Teachability` calls an instance of `TextAnalyzerAgent` to analyze pieces of text in several different ways. This adds extra LLM calls involving a relatively small number of tokens. These calls can add a few seconds to the time a user waits for a response.\n",
"\n",
"This notebook demonstrates how `TeachableAgent` can learn facts, preferences, and skills from users. To chat with `TeachableAgent` yourself, run [chat_with_teachable_agent.py](../test/agentchat/contrib/chat_with_teachable_agent.py).\n",
"This notebook demonstrates how `Teachability` can be added to an agent so that it can learn facts, preferences, and skills from users. To chat with a teachable agent yourself, run [chat_with_teachable_agent.py](../test/agentchat/contrib/chat_with_teachable_agent.py).\n",
"\n",
"## Requirements\n",
"\n",
@ -60,7 +60,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"gpt-4\n"
"gpt-4-1106-preview\n"
]
}
],
@ -71,7 +71,7 @@
" env_or_file=\"OAI_CONFIG_LIST\",\n",
" file_location=\".\",\n",
" filter_dict={\n",
" \"model\": [\"gpt-4\", \"gpt4\", \"gpt-4-32k\"],\n",
" \"model\": [\"gpt-4\", \"gpt-4-1106-preview\", \"gpt4\", \"gpt-4-32k\"],\n",
" },\n",
")\n",
"\n",
@ -83,12 +83,16 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"It first looks for environment variable \"OAI_CONFIG_LIST\" which needs to be a valid json string. If that variable is not found, it then looks for a json file named \"OAI_CONFIG_LIST\". It filters the configs by models (you can filter by other keys as well). After application of this particular filter, only the gpt-4 models are kept.\n",
"It first looks for environment variable \"OAI_CONFIG_LIST\" which needs to be a valid json string. If that variable is not found, it then looks for a json file named \"OAI_CONFIG_LIST\". It filters the configs by models (you can filter by other keys as well). After application of the filter shown above, only the gpt-4 models are considered.\n",
"\n",
"The config list looks like the following:\n",
"The config list may look like the following:\n",
"```python\n",
"config_list = [\n",
" {\n",
" 'model': 'gpt-4-1106-preview',\n",
" 'api_key': '<your OpenAI API key here>',\n",
" },\n",
" {\n",
" 'model': 'gpt-4',\n",
" 'api_key': '<your OpenAI API key here>',\n",
" },\n",
@ -120,7 +124,7 @@
"metadata": {},
"source": [
"## Construct Agents\n",
"For this walkthrough, we start by resetting the teachable agent's memory store. This deletes any memories from prior conversations that may be stored on disk."
"For this walkthrough, we start by creating a teachable agent and resetting its memory store. This deletes any memories from prior conversations that may be stored on disk."
]
},
{
@ -138,21 +142,25 @@
}
],
"source": [
"from autogen.agentchat.contrib.teachable_agent import TeachableAgent\n",
"from autogen import UserProxyAgent\n",
"from autogen.agentchat.contrib.capabilities.teachability import Teachability\n",
"from autogen import ConversableAgent, UserProxyAgent\n",
"\n",
"llm_config = {\n",
" \"config_list\": config_list,\n",
" \"timeout\": 60,\n",
" \"cache_seed\": None, # Use an int to seed the response cache. Use None to disable caching.\n",
"}\n",
"# Start by instantiating any agent that inherits from ConversableAgent.\n",
"teachable_agent = ConversableAgent(\n",
" name=\"teachable_agent\", # The name is flexible, but should not contain spaces to work in group chat.\n",
" llm_config={\"config_list\": config_list, \"timeout\": 120, \"cache_seed\": None}, # Disable caching.\n",
")\n",
"\n",
"teach_config={\n",
" \"verbosity\": 0, # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.\n",
" \"reset_db\": True, # Set to True to start over with an empty database.\n",
" \"path_to_db_dir\": \"./tmp/notebook/teachable_agent_db\", # Path to the directory where the database will be stored.\n",
" \"recall_threshold\": 1.5, # Higher numbers allow more (but less relevant) memos to be recalled.\n",
"}\n",
"# Instantiate the Teachability capability. Its parameters are all optional.\n",
"teachability = Teachability(\n",
" verbosity=0, # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.\n",
" reset_db=True,\n",
" path_to_db_dir=\"./tmp/notebook/teachability_db\",\n",
" recall_threshold=1.5, # Higher numbers allow more (but less relevant) memos to be recalled.\n",
")\n",
"\n",
"# Now add the Teachability capability to the agent.\n",
"teachability.add_to_agent(teachable_agent)\n",
"\n",
"try:\n",
" from termcolor import colored\n",
@ -160,12 +168,8 @@
"\n",
" def colored(x, *args, **kwargs):\n",
" return x\n",
" \n",
"teachable_agent = TeachableAgent(\n",
" name=\"teachableagent\",\n",
" llm_config=llm_config,\n",
" teach_config=teach_config)\n",
"\n",
"# Instantiate a UserProxyAgent to represent the user. But in this notebook, all user input will be simulated.\n",
"user = UserProxyAgent(\n",
" name=\"user\",\n",
" human_input_mode=\"NEVER\",\n",
@ -192,14 +196,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"What is the Vicuna model?\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"I'm sorry, but I don't have information about the \"Vicuna model\" in my current database. Could you please provide some context, description, or details about this model so I can better assist you?\n",
"The term \"Vicuna model\" does not point to a well-known concept or framework in the realms of science, technology, or social sciences as of my last knowledge update in early 2023. It's possible that the term could be a reference to a proprietary model or a concept that has emerged after my last update or it might be a misspelling or a misunderstanding.\n",
"\n",
"If you are referring to \"Vicuña,\" you might be speaking about the animal. The vicuña is a wild South American camelid, which lives in the high alpine areas of the Andes. Vicuñas are relatives of the llama and the alpaca, and they are known for producing extremely fine wool. They were once hunted almost to extinction for their wool but have since been protected and their population has recovered.\n",
"\n",
"If you're referencing something specific, such as a model within a particular field or a term from a proprietary or niche subject, please provide more context or clarify, and I would be happy to help to the best of my ability with the information provided.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@ -219,14 +229,24 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"Vicuna is a 13B-parameter language model released by Meta.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"Thank you for providing the information. The Vicuna model is a large-scale language model developed by Meta. It consists of 13 billion parameters, which is what makes it able to understand and generate human-like text across various topics. Parameters in a neural network are the weights and biases assigned to each connection between the neurons, and having more parameters enables the model to learn and capture more complex patterns and relationships in data. As an AI language model, Vicuna can be utilized in various language-related tasks, such as text completion, translation, summarization, and more.\n",
"My apologies for the confusion. As of my last update, the Vicuna model had not been part of my database. If Vicuna is indeed a 13-billion-parameter language model developed by Meta (formerly Facebook Inc.), then it would be one of the large-scale transformer-based models akin to those like GPT-3 by OpenAI.\n",
"\n",
"For context, a language model like Vicuna, if it exists and as described, would be designed to perform a wide range of natural language processing tasks, such as translation, question answering, and text generation. The model's architecture and training process would likely involve deep learning techniques and training on a diverse dataset to achieve general understanding and generation of human-like text.\n",
"\n",
"The size of the model, measured in the number of parameters (13 billion in this case), would suggest its capability to handle complex language tasks and subtleties. With such a large number of parameters, it would likely be able to generate highly coherent and contextually relevant text.\n",
"\n",
"These language models are trained on vast amounts of text data and require significant computational power for both training and inference processes. Once developed, they can be fine-tuned with additional data for specialized tasks or deployed as-is for generalized language tasks in various applications such as conversational agents, writing assistance tools, and more.\n",
"\n",
"Since I don't have real-time access to data and there might have been advancements or releases after my last update in early 2023, I recommend checking the latest sources or official announcements by Meta for the most accurate and detailed information about the Vicuna language model.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@ -246,16 +266,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"What is the Orca model?\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"The Orca model is a yet-to-be-released language model by Meta, referring to the Research Preview API release of Meta's advanced language model. While I don't have specific details about the Orca model at this time, it is expected that the model would have an even larger number of parameters to further enhance its capabilities compared to its predecessor, Vicuna.\n",
"As of my last update, the Orca model appears to reference a new development that I do not have extensive information on, similar to the earlier reference to the Vicuna model.\n",
"\n",
"As AI language models evolve, they generally become more efficient at understanding and generating human-like text. This makes them more effective in completing a wide range of tasks, including text completion, translation, summarization, and Q&A, as well as providing more accurate and coherent results in various applications.\n",
"Based on the pattern and context provided, if Orca is indeed a language model like Vicuna, it could potentially be another large-scale, transformer-based language model developed for various natural language processing tasks. Generally, these models are trained on extensive corpuses of text to be able to generate text, translate languages, answer questions, and more.\n",
"\n",
"However, I don't have specific details regarding an \"Orca model\" from Meta or any other organization. Without updated information, my insights are limited. To learn more about developments like the Orca or Vicuna models in language technology, I encourage you to consult the latest research publications or announcements from Meta or other organizations involved in NLP and AI. They are often discussed in academic papers, technology news articles, and official press releases from the companies that created them.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@ -275,23 +299,29 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"Orca is a 13B-parameter language model released by Microsoft. It outperforms Vicuna on most tasks.\n",
"Orca is a 13B-parameter language model developed by Microsoft. It outperforms Vicuna on most tasks.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"Thank you for the clarification. Orca is a 13-billion-parameter language model developed by Microsoft. It is designed to outperform Vicuna, the language model released by Meta. With a large number of parameters, Orca is capable of understanding and generating human-like text across diverse subjects, making it highly useful in various applications.\n",
"Thank you for providing the context about the Orca model. Based on the new information you've given, Orca is a language model with 13 billion parameters, similar in size to Meta's Vicuna model, but developed by Microsoft. If it outperforms Vicuna on most tasks, it suggests that it could have been trained on a more diverse dataset, use a more advanced architecture, have more effective training techniques, or some combination of these factors.\n",
"\n",
"Orca's advanced capabilities allow it to excel in tasks such as text completion, translation, summarization, Q&A, and more. Its performance improvement over Vicuna highlights the rapid advancements in natural language processing (NLP) research and the development of increasingly capable AI language models.\n",
"Large language models like Orca and Vicuna are used to perform a variety of complex natural language understanding and generation tasks. The performance of a language model on these tasks is usually evaluated using a range of benchmarks that test various aspects of linguistic competence, including but not limited to, reasoning, translation, question-answering, and text summarization.\n",
"\n",
"Given that Orca reportedly outperforms Vicuna, Microsoft might have employed new advances in training methods or model architectures to increase the efficiency and effectiveness of the model. Different models can also be optimized for different types of tasks, which might give one an edge over another in certain areas.\n",
"\n",
"For detailed and accurate information about the capabilities and specific aspects of the Orca model, you would need to look at the technical documentation or papers released by Microsoft, which would typically contain comprehensive benchmarks and evaluations compared to other models, including Vicuna. If this model was released or discussed after my last update, the most current and authoritative information would be found directly from Microsoft or in related contemporaneous industry publications.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"text = \"Orca is a 13B-parameter language model released by Microsoft. It outperforms Vicuna on most tasks.\"\n",
"text = \"Orca is a 13B-parameter language model developed by Microsoft. It outperforms Vicuna on most tasks.\"\n",
"user.initiate_chat(teachable_agent, message=text, clear_history=False)"
]
},
@ -299,7 +329,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's end our first chat here. The following function needs to be called at the end of each chat, so that `TeachableAgent` can store what the user has taught it."
"Let's end our first chat here, and start a new chat by clearing the previous chat's history, by passing `clear_history=True` to `initiate_chat`. At this point, a common LLM-based assistant would forget everything from the last chat. But a teachable agent can retrieve memories from its vector DB as needed, allowing it to recall and reason over things that the user taught it in earlier conversations."
]
},
{
@ -311,39 +341,22 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[93m\n",
"REVIEWING CHAT FOR USER TEACHINGS TO REMEMBER\u001b[0m\n"
]
}
],
"source": [
"teachable_agent.learn_from_user_feedback()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's start a new chat by clearing the previous chat's history. At this point, common LLM-based assistants would forget everything from the last chat. But `TeachableAgent` can retrieve memories from its vector DB as needed, allowing it to recall and reason over facts that the user taught it in earlier conversations."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"How does the Vicuna model compare to the Orca model?\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"The Vicuna and Orca models are both 13B-parameter language models released by Meta and Microsoft, respectively. In terms of performance, Orca has been reported to outperform Vicuna on most tasks. However, without detailed information about specific tasks and benchmarks, it is difficult to provide a more comprehensive comparison. Generally speaking, both models are advanced language models that aim to provide high-quality natural language understanding and generation, but Orca appears to have an edge in terms of overall performance.\n",
"The Vicuna model and the Orca model are both large-scale language models with a significant number of parameters—13 billion, to be exact.\n",
"\n",
"The Vicuna model was developed by Meta (formerly Facebook). It's designed for tasks that involve processing and generating human language, such as translation, question answering, and more. Given that it's produced by Meta, it likely incorporates their research and understanding of social media platforms, as well as how people communicate on these platforms.\n",
"\n",
"The Orca model, developed by Microsoft, is also geared for similar natural language processing tasks. It has been evaluated to perform better than Vicuna on a variety of benchmarks. Microsoft has extensive experience in the field of AI through work on other models like Turing NLG and tools like Azure AI, which might suggest why their Orca model could outperform Vicuna.\n",
"\n",
"Both models represent the cutting edge of AI language processing as of the time of our last update; continued advances in this area happen rapidly, and newer models or updates to these models might have been released since. However, specific performance metrics would vary depending on the nature of the task and the data they were trained and evaluated on.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@ -364,14 +377,14 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"Please summarize this abstract.\n",
"\n",
@ -381,9 +394,11 @@
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"AutoGen is an open-source framework designed to enable developers to create LLM applications with multiple customizable agents that can converse with each other to complete tasks. These agents can operate using different combinations of LLMs, human inputs, and tools, allowing developers to define agent interaction behaviors flexibly. AutoGen supports programming flexible conversation patterns using both natural language and code, making it suitable for building diverse applications with varying complexities and LLM capacities. Its effectiveness has been demonstrated through empirical studies across various domains including mathematics, coding, operations research, decision-making, and entertainment.\n",
"AutoGen is an open-source framework designed to facilitate the creation of applications using large language models (LLMs) through the use of multiple conversational agents. These agents can be tailored to users' needs and are capable of interaction in multiple modes, including with other LLMs, human input, and additional tools. With AutoGen, developers have the flexibility to program agent interactions using both natural language and code, enabling the creation of complex patterns suitable for a wide range of applications. The framework has been proven effective across various fields, such as math, coding, question answering, and entertainment, based on empirical studies conducted to test its capabilities.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@ -403,19 +418,19 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"But that's unstructured. So let's teach the agent our preference."
"But that's unstructured. So let's teach the agent our preference for a particular structure."
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"Please summarize this abstract. \n",
"When I'm summarizing an abstract, I try to make the summary contain just three short bullet points: the title, the innovation, and the key empirical results.\n",
@ -426,11 +441,13 @@
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"- Title: AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation\n",
"- Innovation: Open-source framework for creating customizable LLM applications through agent conversations, supporting various modes and interaction behaviors.\n",
"- Key Empirical Results: Demonstrated effectiveness across diverse application domains, including mathematics, coding, question answering, and more.\n",
"- Innovation: AutoGen, an open-source framework that supports building large language model (LLM) applications by enabling conversation among multiple customizable and conversable agents.\n",
"- Key Empirical Results: Demonstrated effectiveness across a variety of domains, including mathematics, coding, question answering, operations research, and entertainment.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@ -451,21 +468,19 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"That's much better, but will the teachable agent remember these preferences in the future, for a different paper? Let's start a new chat to find out!"
"That's much better, but will the teachable agent remember these preferences in the future, even for a different paper? Let's start a new chat to find out!"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[93m\n",
"REVIEWING CHAT FOR USER TEACHINGS TO REMEMBER\u001b[0m\n",
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"Please summarize this abstract.\n",
"\n",
@ -474,19 +489,21 @@
"Artificial intelligence (AI) researchers have been developing and refining large language models (LLMs) that exhibit remarkable capabilities across a variety of domains and tasks, challenging our understanding of learning and cognition. The latest model developed by OpenAI, GPT-4, was trained using an unprecedented scale of compute and data. In this paper, we report on our investigation of an early version of GPT-4, when it was still in active development by OpenAI. We contend that (this early version of) GPT-4 is part of a new cohort of LLMs (along with ChatGPT and Google's PaLM for example) that exhibit more general intelligence than previous AI models. We discuss the rising capabilities and implications of these models. We demonstrate that, beyond its mastery of language, GPT-4 can solve novel and difficult tasks that span mathematics, coding, vision, medicine, law, psychology and more, without needing any special prompting. Moreover, in all of these tasks, GPT-4's performance is strikingly close to human-level performance, and often vastly surpasses prior models such as ChatGPT. Given the breadth and depth of GPT-4's capabilities, we believe that it could reasonably be viewed as an early (yet still incomplete) version of an artificial general intelligence (AGI) system. In our exploration of GPT-4, we put special emphasis on discovering its limitations, and we discuss the challenges ahead for advancing towards deeper and more comprehensive versions of AGI, including the possible need for pursuing a new paradigm that moves beyond next-word prediction. We conclude with reflections on societal influences of the recent technological leap and future research directions.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"- Title: Sparks of Artificial General Intelligence: Early experiments with GPT-4\n",
"- Innovation: GPT-4, an LLM with remarkable capabilities, demonstrates human-level performance across various domains, like math, coding, vision, medicine, law, and psychology.\n",
"- Key results: GPT-4 significantly surpasses prior models, suggesting it may be an early version of AGI; limitations and challenges toward deeper AGI are also discussed.\n",
"\n",
"- Innovation: An exploration of an early version of GPT-4, indicating it might represent a step towards artificial general intelligence (AGI) by demonstrating a broad set of skills across diverse tasks, closely approaching human-level performance, without the need for specialized prompting.\n",
"\n",
"- Key Empirical Results: GPT-4 shows advanced capabilities in tasks related to various fields such as mathematics, coding, vision, medicine, law, and psychology, surpassing previous models like ChatGPT, suggesting it could be seen as an early form of AGI. Challenges and future research directions towards achieving more complete AGI systems are also discussed.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"teachable_agent.learn_from_user_feedback()\n",
"\n",
"text = \"\"\"Please summarize this abstract.\n",
"\n",
"Sparks of Artificial General Intelligence: Early experiments with GPT-4\n",
@ -514,7 +531,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"Consider the identity: \n",
"9 * 4 + 6 * 6 = 72\n",
@ -523,23 +540,43 @@
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"Step 1: Identify the current equation and the target value on the right-hand side.\n",
"Current equation: 9 * 4 + 6 * 6 = 72\n",
"Target value: 99\n",
"To solve this problem, we need to find a way to add exactly 27 (since 99 - 72 = 27) to the left hand side of the equation by modifying only one of the integers in the equation. \n",
"\n",
"Step 2: Determine what numbers can be changed on the left-hand side.\n",
"Possible changes: 9, 4, 6 (first occurrence), and 6 (second occurrence)\n",
"Let's start by outlining a plan:\n",
"\n",
"Step 3: Evaluate whether each change in the numbers on the left-hand side can result in the target value.\n",
"9 -> If changed, it would create a larger product.\n",
"4 -> If changed, the product would increase or decrease by more than the amount needed to reach 99.\n",
"6 (first occurrence) -> If increased to 7, then the equation would be 9 * 4 + 7 * 6 = 94, which is still not 99.\n",
"6 (second occurrence) -> If increased to 7, then the equation would be 9 * 4 + 6 * 7 = 99.\n",
"1. Get the current total of the equation, which we already know equals 72.\n",
"2. Identify the target total, which we want to be 99.\n",
"3. Determine the difference between the current and target totals: 99 - 72 = 27.\n",
"4. Inspect each of the four terms on the left (9, 4, 6, and 6), and see if changing one of them can bring us to our target total.\n",
"5. We need a change that will add 27 to the left side. This means thinking about the factors and multiples related to each term.\n",
"6. We'll consider the impact of increasing or decreasing each term and see which modification gets us to add 27 exactly.\n",
"\n",
"Step 4: Write down the solution.\n",
"The solution is: 9 * 4 + 6 * 7\n",
"Now let's proceed with the changes:\n",
"\n",
"- If we increase or decrease the first term '9', which gets multiplied by '4', we could change '9' to '9 + a', where 'a' is some integer. But since we'll multiply by '4', this change wouldn't give us an additional 27 unless 'a' was a fraction, which is not permitted as we are only modifying integers.\n",
"- The second term '4' is multiplied by '9'. If we changed '4' to '5', we would add '9' to the sum, which is not enough; changing '4' to '6' would mean adding '18', which is also not enough.\n",
"- If we change one of the sixes being multiplied by '6', then we could consider turning one '6' into '9'. This would add '3 * 6 = 18' to the sum, which is too small of an increment as well.\n",
"- The largest increase we can get is by changing the '6' that is being multiplied by itself, meaning '6 * 6'. If we change one '6' to '9', then we get '9 * 6', which increases the sum by '3 * 6' = 18, which doesn't achieve the desired 27.\n",
"- But if we take the approach of changing a '6' that multiplies the '9' to a '9', thus making it '9 * 9', we add '3 * 9' = 27, exactly what we need.\n",
"\n",
"Following this logic, we can modify the second '6' that is being multiplied by the '9'. So we change '9 * 4 + 6 * 6' to '9 * 4 + 9 * 6'.\n",
"\n",
"Calculating the new values:\n",
"\n",
"9 * 4 = 36\n",
"9 * 6 = 54\n",
"\n",
"Adding these together will indeed give us 99:\n",
"\n",
"36 + 54 = 90\n",
"\n",
"Thus, our solution is:\n",
"\n",
"The solution is: 9 * 4 + 9 * 6\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@ -570,7 +607,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"Consider the identity: \n",
"9 * 4 + 6 * 6 = 72\n",
@ -591,40 +628,47 @@
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"Step 1: Identify the original number (E) and the final number (F) on the right-hand side.\n",
"E = 72\n",
"F = 99\n",
"Given the new set of instructions and the correction that according to a past memory, the solution is \"9 * 1 + 6 * 9\", let's follow the steps carefully to arrive at the correct modified equation.\n",
"\n",
"Step 2: Calculate the difference between the two, G = F - E.\n",
"G = 99 - 72 = 27\n",
"Here's the step-by-step plan following your advice:\n",
"\n",
"Step 3: Examine the numbers on the left one by one until finding one that divides evenly into G, where negative integers are allowed.\n",
"9 (divides evenly into 27)\n",
"4 (does not)\n",
"6 (first occurrence, does not)\n",
"6 (second occurrence, does not)\n",
"1. Let E denote the original number on the right.\n",
" E = 72\n",
" \n",
"2. Let F denote the final number on the right.\n",
" F = 99\n",
" \n",
"3. Calculate the difference between the two, G = F - E.\n",
" G = 99 - 72 = 27\n",
" \n",
"4. Examine the numbers on the left one by one until finding one that divides evenly into G, where negative integers are allowed.\n",
" We have the numbers 9, 4, and 6 on the left.\n",
" Among these, 9 divides evenly into 27, because 27 / 9 = 3, which is an integer.\n",
" \n",
"5. Calculate J = G / H, where H is the number we found divides evenly into G.\n",
" J = 27 / 9 = 3\n",
" \n",
"6. Verify that J is an integer, and that H * J = G.\n",
" J is indeed an integer, and 9 * 3 = 27, which equals G.\n",
" \n",
"7. Find the number on the left which is multiplied by H (9, in this case), and call it K.\n",
" The number multiplied by our H (which is 9) on the left side of the equation is 4. So K = 4.\n",
" \n",
"8. Change K to K + J.\n",
" K + J = 4 + 3 = 7\n",
" \n",
"9. Recompute the value on the left, and verify that it equals F.\n",
" Substitute K with K + J in the original equation:\n",
" 9 * K + 6 * 6 = 9 * 7 + 6 * 6 = 63 + 36 = 99\n",
" \n",
"The recomputed value on the left is indeed 99, which matches our F.\n",
"\n",
"Step 4: Calculate J = G / H. This is the number of times that H divides into G.\n",
"H = 9 (since 9 divides evenly into 27)\n",
"J = 27 / 9 = 3\n",
"So, following your advice and the verified plan:\n",
"\n",
"Step 5: Verify that J is an integer, and that H * J = G.\n",
"H * J = 9 * 3 = 27\n",
"J = 3 (which is an integer)\n",
"\n",
"Step 6: Find the number on the left which is multiplied by H, and call it K.\n",
"Number on the leftside of the equation that is multiplied by 9: 4\n",
"Therefore, K = 4\n",
"\n",
"Step 7: Change K to K + J.\n",
"K + J = 4 + 3 = 7\n",
"\n",
"Step 8: Recompute the value on the left, and verify that it equals F.\n",
"9 * 7 + 6 * 6 = 63 + 36 = 99 (which is equal to F)\n",
"\n",
"Step 9: Write down the solution.\n",
"The solution is: 9 * 7 + 6 * 6\n",
"\n",
"--------------------------------------------------------------------------------\n"
@ -668,9 +712,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[93m\n",
"REVIEWING CHAT FOR USER TEACHINGS TO REMEMBER\u001b[0m\n",
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"Consider the identity: \n",
"9 * 4 + 6 * 6 = 72\n",
@ -679,29 +721,45 @@
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"Let's follow the advice to modify one integer on the left hand side of the equation to make the right hand side become 99:\n",
"Let's apply the steps you've provided to solve the problem at hand:\n",
"\n",
"1. Let E denote the original number on the right: E = 72.\n",
"2. Let F denote the final number on the right: F = 99.\n",
"3. Calculate the difference between the two, G = F - E: G = 99 - 72 = 27.\n",
"4. Examine the numbers on the left one by one to find one that divides evenly into G: 9, 4, 6. Since 9 divides 27 evenly, we will work with 9 (H = 9).\n",
"5. Calculate J = G / H: J = 27 / 9 = 3.\n",
"6. Verify that J is an integer, and that H * J = G: J is an integer, and 9 * 3 = 27.\n",
"7. Find the number on the left which is multiplied by H (9), and call it K: K = 4.\n",
"8. Change K to K + J: K_new = 4 + 3 = 7.\n",
"9. Recompute the value on the left, and verify that it equals F: (9 * 7) + (6 * 6) = 63 + 36 = 99.\n",
"1. Let E denote the original number on the right, which is 72.\n",
"2. Let F denote the final number on the right, which we want to be 99.\n",
"3. Calculate the difference between the two, G = F - E, so G = 99 - 72.\n",
"4. G = 27.\n",
"\n",
"The solution is: 9 * 7 + 6 * 6\n",
"Now we need to examine the numbers on the left one by one until finding one that divides evenly into G. In our context, the numbers on the left to consider are the multipliers, which are 9 and 6 (from the expressions 9 * 4 and 6 * 6).\n",
"\n",
"5. We look for a number H among our multipliers that can divide 27 without a remainder.\n",
"\n",
"6 is a potential candidate for H (since 6 * 6 = 36 and is part of our equation), but cannot evenly divide 27. But 9 can divide 27 evenly. Therefore, H = 9.\n",
"\n",
"6. Calculate J = G / H. Hence, J = 27 / 9.\n",
"7. J = 3. This means that H * J = 27, which confirms that J is a valid integer and that 9 * 3 = 27.\n",
"8. The number on the left which is multiplied by 9 is 4, and we call it K. So, K = 4.\n",
"\n",
"Now, we need to adjust K to accommodate the difference of 27.\n",
"\n",
"9. Change K to K + J. Thus, K = 4 + 3.\n",
"10. K = 7.\n",
"\n",
"Now we substitute K back into the equation:\n",
"\n",
"The solution is: 9 * 7 + 6 * 6.\n",
"\n",
"Let's verify that when we recompute the left side, it equals F (99):\n",
"\n",
"9 * 7 + 6 * 6 = 63 + 36 = 99, which confirms our solution.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"teachable_agent.learn_from_user_feedback()\n",
"\n",
"text = \"\"\"Consider the identity: \n",
"9 * 4 + 6 * 6 = 72\n",
"Can you modify exactly one integer (and not more than that!) on the left hand side of the equation so the right hand side becomes 99?\n",
@ -726,7 +784,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33muser\u001b[0m (to teachableagent):\n",
"\u001b[33muser\u001b[0m (to teachable_agent):\n",
"\n",
"Consider the identity: \n",
"8 * 3 + 7 * 9 = 87\n",
@ -735,21 +793,35 @@
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mteachableagent\u001b[0m (to user):\n",
"\u001b[31m\n",
">>>>>>>> USING AUTO REPLY...\u001b[0m\n",
"\u001b[33mteachable_agent\u001b[0m (to user):\n",
"\n",
"Let's follow the advice to modify one integer on the left hand side of the equation to make the right hand side become 59:\n",
"Let's apply the plan step-by-step to find the correct modification:\n",
"\n",
"1. Let E denote the original number on the right: E = 87.\n",
"2. Let F denote the final number on the right: F = 59.\n",
"3. Calculate the difference between the two, G = F - E: G = 59 - 87 = -28.\n",
"4. Examine the numbers on the left one by one to find one that divides evenly into G: 8, 3, 7, 9. Since 7 divides -28 evenly, we will work with 7 (H = 7).\n",
"5. Calculate J = G / H: J = -28 / 7 = -4.\n",
"6. Verify that J is an integer, and that H * J = G: J is an integer, and 7 * (-4) = -28.\n",
"7. Find the number on the left which is multiplied by H (7), and call it K: K = 9.\n",
"8. Change K to K + J: K_new = 9 + (-4) = 5.\n",
"9. Recompute the value on the left, and verify that it equals F: (8 * 3) + (7 * 5) = 24 + 35 = 59.\n",
"1. Let E denote the original number on the right, which is 87.\n",
"2. Let F denote the final number on the right, which we want to be 59.\n",
"3. Calculate the difference between the two, G = F - E, so G = 59 - 87.\n",
"4. G = -28, and we note that the result is negative, meaning we need to decrease the left-hand side of the identity to reach the final number.\n",
"\n",
"The solution is: 8 * 3 + 7 * 5\n",
"We will now look at the left side of the equation and determine which number can be modified to account for this difference:\n",
"\n",
"The left side has the multipliers 8 (from 8 * 3) and 7 (from 7 * 9). We need to find one that divides -28 evenly.\n",
"\n",
"5. First, let's try H = 8. We find that -28 divided by 8 gives J that is not an integer.\n",
"6. Next, Let's check H = 7. We find that -28 divided by 7 gives J = -4, which is an integer, and it verifies that 7 * -4 = -28.\n",
"\n",
"7. The number on the left which is multiplied by 7 is 9, and we call it K. So, K = 9.\n",
"\n",
"8. Change K to K + J. Therefore, K becomes 9 + (-4) = 5.\n",
"\n",
"Now we substitute K back into the equation and get the equation:\n",
"\n",
"The solution is: 8 * 3 + 7 * 5.\n",
"\n",
"Let's verify the equation gives the right hand side of 59:\n",
"\n",
"8 * 3 + 7 * 5 =24 + 35 = 59, which confirms our solution.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@ -781,7 +853,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.5"
}
},
"nbformat": 4,

File diff suppressed because it is too large

View File

@ -1,5 +1,6 @@
from autogen import UserProxyAgent, config_list_from_json
from autogen.agentchat.contrib.teachable_agent import TeachableAgent
from autogen.agentchat.contrib.capabilities.teachability import Teachability
from autogen import ConversableAgent
import os
import sys
@ -16,38 +17,44 @@ except ImportError:
return x
verbosity = 0 # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
recall_threshold = 1.5 # Higher numbers allow more (but less relevant) memos to be recalled.
cache_seed = None # Use an int to seed the response cache. Use None to disable caching.
# Specify the model to use. GPT-3.5 is less reliable than GPT-4 at learning from user input.
filter_dict = {"model": ["gpt-4-1106-preview"]}
# filter_dict = {"model": ["gpt-3.5-turbo-1106"]}
# filter_dict = {"model": ["gpt-4-0613"]}
# filter_dict = {"model": ["gpt-3.5-turbo-0613"]}
filter_dict = {"model": ["gpt-4"]}
# filter_dict = {"model": ["gpt-4"]}
# filter_dict = {"model": ["gpt-35-turbo-16k", "gpt-3.5-turbo-16k"]}
def create_teachable_agent(reset_db=False):
"""Instantiates a TeachableAgent using the settings from the top of this file."""
"""Instantiates a teachable agent using the settings from the top of this file."""
# Load LLM inference endpoints from an env variable or a file
# See https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints
# and OAI_CONFIG_LIST_sample
config_list = config_list_from_json(env_or_file=OAI_CONFIG_LIST, filter_dict=filter_dict, file_location=KEY_LOC)
teachable_agent = TeachableAgent(
name="teachableagent",
llm_config={"config_list": config_list, "timeout": 120, "cache_seed": cache_seed},
teach_config={
"verbosity": verbosity,
"reset_db": reset_db,
"path_to_db_dir": "./tmp/interactive/teachable_agent_db",
"recall_threshold": recall_threshold,
},
# Start by instantiating any agent that inherits from ConversableAgent.
teachable_agent = ConversableAgent(
name="teachable_agent",
llm_config={"config_list": config_list, "timeout": 120, "cache_seed": None}, # Disable caching.
)
# Instantiate the Teachability capability. Its parameters are all optional.
teachability = Teachability(
verbosity=0, # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
reset_db=reset_db,
path_to_db_dir="./tmp/interactive/teachability_db",
recall_threshold=1.5, # Higher numbers allow more (but less relevant) memos to be recalled.
)
# Now add the Teachability capability to the agent.
teachability.add_to_agent(teachable_agent)
return teachable_agent
def interact_freely_with_user():
"""Starts a free-form chat between the user and TeachableAgent."""
"""Starts a free-form chat between the user and a teachable agent."""
# Create the agents.
print(colored("\nLoading previous memory (if any) from disk.", "light_cyan"))
@ -57,13 +64,7 @@ def interact_freely_with_user():
# Start the chat.
teachable_agent.initiate_chat(user, message="Greetings, I'm a teachable user assistant! What's on your mind today?")
# Let the teachable agent remember things that should be learned from this chat.
teachable_agent.learn_from_user_feedback()
# Wrap up.
teachable_agent.close_db()
if __name__ == "__main__":
"""Lets the user test TeachableAgent interactively."""
"""Lets the user test a teachable agent interactively."""
interact_freely_with_user()

View File

@ -11,7 +11,7 @@ from test_assistant_agent import OAI_CONFIG_LIST, KEY_LOC # noqa: E402
try:
from openai import OpenAI
from autogen.agentchat.contrib.teachable_agent import TeachableAgent
from autogen.agentchat.contrib.capabilities.teachability import Teachability
except ImportError:
skip = True
else:
@ -25,38 +25,40 @@ except ImportError:
return x
# Set verbosity levels to maximize code coverage.
qa_verbosity = 0 # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
skill_verbosity = 3 # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
assert_on_error = False # GPT-4 nearly always succeeds on these unit tests, but GPT-3.5 is a bit less reliable.
recall_threshold = 1.5 # Higher numbers allow more (but less relevant) memos to be recalled.
cache_seed = None # Use an int to seed the response cache. Use None to disable caching.
# Specify the model to use by uncommenting one of the following lines.
# filter_dict={"model": ["gpt-4-1106-preview"]}
# filter_dict={"model": ["gpt-4-0613"]}
# filter_dict={"model": ["gpt-3.5-turbo-1106"]}
# filter_dict={"model": ["gpt-3.5-turbo-0613"]}
# filter_dict={"model": ["gpt-4"]}
filter_dict = {"model": ["gpt-35-turbo-16k", "gpt-3.5-turbo-16k"]}
def create_teachable_agent(reset_db=False, verbosity=0):
"""Instantiates a TeachableAgent using the settings from the top of this file."""
"""Instantiates a teachable agent using the settings from the top of this file."""
# Load LLM inference endpoints from an env variable or a file
# See https://microsoft.github.io/autogen/docs/FAQ#set-your-api-endpoints
# and OAI_CONFIG_LIST_sample
config_list = config_list_from_json(env_or_file=OAI_CONFIG_LIST, filter_dict=filter_dict, file_location=KEY_LOC)
teachable_agent = TeachableAgent(
name="teachableagent",
llm_config={"config_list": config_list, "timeout": 120, "cache_seed": cache_seed},
teach_config={
"verbosity": verbosity,
"reset_db": reset_db,
"path_to_db_dir": "./tmp/teachable_agent_db",
"recall_threshold": recall_threshold,
},
# Start by instantiating any agent that inherits from ConversableAgent.
teachable_agent = ConversableAgent(
name="teachable_agent",
llm_config={"config_list": config_list, "timeout": 120, "cache_seed": None}, # Disable caching.
)
return teachable_agent
# Instantiate the Teachability capability. Its parameters are all optional.
teachability = Teachability(
verbosity=verbosity,
reset_db=reset_db,
path_to_db_dir="./tmp/teachability_db",
recall_threshold=1.5, # Higher numbers allow more (but less relevant) memos to be recalled.
)
# Now add the Teachability capability to the agent.
teachability.add_to_agent(teachable_agent)
return teachable_agent, teachability
def check_agent_response(teachable_agent, user, correct_answer):
@ -64,8 +66,6 @@ def check_agent_response(teachable_agent, user, correct_answer):
agent_response = user.last_message(teachable_agent)["content"]
if correct_answer not in agent_response:
print(colored(f"\nTEST FAILED: EXPECTED ANSWER {correct_answer} NOT FOUND IN AGENT RESPONSE", "light_red"))
if assert_on_error:
assert correct_answer in agent_response
return 1
else:
print(colored(f"\nTEST PASSED: EXPECTED ANSWER {correct_answer} FOUND IN AGENT RESPONSE", "light_cyan"))
@ -76,13 +76,14 @@ def use_question_answer_phrasing():
"""Tests whether the teachable agent can answer a question after being taught the answer in a previous chat."""
print(colored("\nTEST QUESTION-ANSWER PHRASING", "light_cyan"))
num_errors, num_tests = 0, 0
teachable_agent = create_teachable_agent(
reset_db=True, verbosity=qa_verbosity
teachable_agent, teachability = create_teachable_agent(
reset_db=True,
verbosity=0, # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
) # For a clean test, clear the agent's memory.
user = ConversableAgent("user", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
# Prepopulate memory with a few arbitrary memos, just to make retrieval less trivial.
teachable_agent.prepopulate_db()
teachability.prepopulate_db()
# Ask the teachable agent to do something using terminology it doesn't understand.
user.initiate_chat(recipient=teachable_agent, message="What is the twist of 5 and 7?")
@ -95,9 +96,6 @@ def use_question_answer_phrasing():
num_errors += check_agent_response(teachable_agent, user, "23")
num_tests += 1
# Let the teachable agent remember things that should be learned from this chat.
teachable_agent.learn_from_user_feedback()
# Now start a new chat to clear the context, and require the teachable agent to use its new knowledge.
print(colored("\nSTARTING A NEW CHAT WITH EMPTY CONTEXT", "light_cyan"))
user.initiate_chat(recipient=teachable_agent, message="What's the twist of 8 and 3 and 2?")
@ -105,7 +103,6 @@ def use_question_answer_phrasing():
num_tests += 1
# Wrap up.
teachable_agent.close_db()
return num_errors, num_tests
@ -113,13 +110,14 @@ def use_task_advice_pair_phrasing():
"""Tests whether the teachable agent can demonstrate a new skill after being taught a task-advice pair in a previous chat."""
print(colored("\nTEST TASK-ADVICE PHRASING", "light_cyan"))
num_errors, num_tests = 0, 0
teachable_agent = create_teachable_agent(
reset_db=True, verbosity=skill_verbosity # For a clean test, clear the teachable agent's memory.
teachable_agent, teachability = create_teachable_agent(
reset_db=True, # For a clean test, clear the teachable agent's memory.
verbosity=3, # 0 for basic info, 1 to add memory operations, 2 for analyzer messages, 3 for memo lists.
)
user = ConversableAgent("user", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
# Prepopulate memory with a few arbitrary memos, just to make retrieval less trivial.
teachable_agent.prepopulate_db()
teachability.prepopulate_db()
# Ask the teachable agent to do something, and provide some helpful advice.
user.initiate_chat(
@ -129,9 +127,6 @@ def use_task_advice_pair_phrasing():
num_errors += check_agent_response(teachable_agent, user, "23")
num_tests += 1
# Let the teachable agent remember things that should be learned from this chat.
teachable_agent.learn_from_user_feedback()
# Now start a new chat to clear the context, and require the teachable agent to use its new knowledge.
print(colored("\nSTARTING A NEW CHAT WITH EMPTY CONTEXT", "light_cyan"))
user.initiate_chat(recipient=teachable_agent, message="Please calculate the twist of 8 and 3 and 2.")
@ -139,7 +134,6 @@ def use_task_advice_pair_phrasing():
num_tests += 1
# Wrap up.
teachable_agent.close_db()
return num_errors, num_tests
@ -184,20 +178,17 @@ def test_teachability_accuracy():
num_trials = 10 # The expected probability of failure is about 0.3 on each trial.
for trial in range(num_trials):
teachable_agent = create_teachable_agent(
teachable_agent, teachability = create_teachable_agent(
reset_db=True, verbosity=0
) # For a clean test, clear the agent's memory.
user = ConversableAgent("user", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
# Prepopulate memory with a few arbitrary memos, just to make retrieval less trivial.
teachable_agent.prepopulate_db()
teachability.prepopulate_db()
# Tell the teachable agent something it wouldn't already know.
user.initiate_chat(recipient=teachable_agent, message="My favorite color is teal.")
# Let the teachable agent remember things that should be learned from this chat.
teachable_agent.learn_from_user_feedback()
# Now start a new chat to clear the context, and ask the teachable agent about the new information.
print(colored("\nSTARTING A NEW CHAT WITH EMPTY CONTEXT", "light_cyan"))
user.initiate_chat(recipient=teachable_agent, message="What's my favorite color?")
@ -205,9 +196,6 @@ def test_teachability_accuracy():
print(colored(f"\nTRIAL {trial + 1} OF {num_trials} FINISHED", "light_cyan"))
# Wrap up.
teachable_agent.close_db()
# Exit on the first success.
if num_errors == 0:
return

View File

@ -1,5 +1,5 @@
---
title: AutoGen's TeachableAgent
title: AutoGen's Teachable Agents
authors: rickyloynd-microsoft
tags: [LLM, teach]
---
@ -7,33 +7,35 @@ tags: [LLM, teach]
![Teachable Agent Architecture](img/teachable-arch.png)
**TL;DR:**
* We introduce **TeachableAgent** (which uses **TextAnalyzerAgent**) so that users can teach their LLM-based assistants new facts, preferences, and skills.
* We showcase examples of `TeachableAgent` learning and later recalling facts, preferences, and skills in subsequent chats.
* We introduce **Teachable Agents** so that users can teach their LLM-based assistants new facts, preferences, and skills.
* We showcase examples of teachable agents learning and later recalling facts, preferences, and skills in subsequent chats.
## Introduction
Conversational assistants based on LLMs can remember the current chat with the user, and can also demonstrate in-context learning of user teachings during the conversation. But the assistant's memories and learnings are lost once the chat is over, or when a single chat grows too long for the LLM to handle effectively. Then in subsequent chats the user is forced to repeat any necessary instructions over and over.
`TeachableAgent` addresses these limitations by persisting user teachings across chat boundaries in long-term memory implemented as a vector database. Memory is automatically saved to disk at the end of each chat, then loaded from disk at the start of the next. Instead of copying all of memory into the context window, which would eat up valuable space, individual memories (called memos) are retrieved into context as needed. This allows the user to teach frequently used facts and skills to the teachable agent just once, and have it recall them in later chats.
`Teachability` addresses these limitations by persisting user teachings across chat boundaries in long-term memory implemented as a vector database. Instead of copying all of memory into the context window, which would eat up valuable space, individual memories (called memos) are retrieved into context as needed. This allows the user to teach frequently used facts and skills to the teachable agent just once, and have it recall them in later chats.
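
To make this concrete, here is a minimal sketch of what such a memo store could look like, using chromadb purely as an illustrative vector database; the class and method names below are hypothetical and are not the actual AutoGen implementation.

```python
import chromadb


class MemoStore:
    """Illustrative long-term memory: each memo pairs an input passage with the text to recall."""

    def __init__(self, path_to_db_dir="./tmp/teachability_db", recall_threshold=1.5):
        # Persist the vector DB to disk so memos survive across chat sessions.
        self.client = chromadb.PersistentClient(path=path_to_db_dir)
        self.collection = self.client.get_or_create_collection("memos")
        self.recall_threshold = recall_threshold  # Higher values recall more (but less relevant) memos.

    def store(self, input_text: str, output_text: str) -> None:
        # The input text is embedded for similarity search; the output text is what gets recalled later.
        memo_id = str(self.collection.count() + 1)
        self.collection.add(documents=[input_text], metadatas=[{"output": output_text}], ids=[memo_id])

    def retrieve(self, query: str, n_results: int = 3) -> list:
        if self.collection.count() == 0:
            return []
        # Return only the memos whose embedding distance falls under the recall threshold.
        results = self.collection.query(query_texts=[query], n_results=min(n_results, self.collection.count()))
        return [
            meta["output"]
            for meta, dist in zip(results["metadatas"][0], results["distances"][0])
            if dist < self.recall_threshold
        ]
```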
In order to make effective decisions about memo storage and retrieval, `TeachableAgent` calls an instance of `TextAnalyzerAgent` (another AutoGen agent) to identify and reformulate text as needed for remembering facts, preferences, and skills. Note that this adds extra LLM calls involving a relatively small number of tokens, which can add a few seconds to the time a user waits for each response.
Any agent whose class inherits from `ConversableAgent` can be made teachable by instantiating a `Teachability` object and calling its `add_to_agent(agent)` method.
In order to make effective decisions about memo storage and retrieval, the `Teachability` object calls an instance of `TextAnalyzerAgent` (another AutoGen agent) to identify and reformulate text as needed for remembering facts, preferences, and skills. Note that this adds extra LLM calls involving a relatively small number of tokens, which can add a few seconds to the time a user waits for each response.
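
One plausible way these pieces fit together is sketched below, reusing the `MemoStore` sketched above. It assumes the capability can intercept each incoming message before the agent replies, and the `analyzer.analyze_text(...)` interface is an assumption for illustration rather than the exact API.

```python
def preprocess_user_message(text: str, memo_store: MemoStore, analyzer) -> str:
    """Illustrative pre-reply step: store any new teaching, then expand the message with relevant memos."""
    # One extra LLM call: ask the analyzer whether the message contains something worth remembering.
    verdict = analyzer.analyze_text(text, "Does this message teach a fact, preference, or skill to remember? Answer yes or no.")
    if verdict.strip().lower().startswith("yes"):
        # A second call reformulates the teaching into a concise, general memo before storing it.
        memo = analyzer.analyze_text(text, "Restate the teaching as a brief note that could help with similar tasks later.")
        memo_store.store(input_text=text, output_text=memo)

    # Retrieve memos related to the current message and append them for the LLM to draw on.
    relevant = memo_store.retrieve(text)
    if relevant:
        text += "\n\n# Memories that might help\n" + "\n".join(f"- {m}" for m in relevant)
    return text
```

Each stored teaching costs an extra analyzer call or two, which is where the added latency mentioned above comes from.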
## Run It Yourself
AutoGen contains three code examples that use `TeachableAgent`.
AutoGen contains four code examples that use `Teachability`.
1. Run [chat_with_teachable_agent.py](https://github.com/microsoft/autogen/blob/main/test/agentchat/contrib/chat_with_teachable_agent.py) to converse with `TeachableAgent`.
1. Run [chat_with_teachable_agent.py](https://github.com/microsoft/autogen/blob/main/test/agentchat/contrib/chat_with_teachable_agent.py) to converse with a teachable agent.
2. Use the Jupyter notebook [agentchat_teachability.ipynb](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachability.ipynb) to step through examples discussed below.
2. Run [test_teachable_agent.py](https://github.com/microsoft/autogen/blob/main/test/agentchat/contrib/test_teachable_agent.py) for quick unit testing of a teachable agent.
3. Run [test_teachable_agent.py](https://github.com/microsoft/autogen/blob/main/test/agentchat/contrib/test_teachable_agent.py) for quick unit testing of `TeachableAgent`.
3. Use the Jupyter notebook [agentchat_teachability.ipynb](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachability.ipynb) to step through examples discussed below.
4. Use the Jupyter notebook [agentchat_teachable_oai_assistants.ipynb](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachable_oai_assistants.ipynb) to make arbitrary OpenAI Assistants teachable through `GPTAssistantAgent`.
## Basic Usage of TeachableAgent
## Basic Usage of Teachability
1. Install dependencies
Please install pyautogen with the [teachable] option before using TeachableAgent.
Please install pyautogen with the [teachable] option before using `Teachability`.
```bash
pip install "pyautogen[teachable]"
```
@ -41,7 +43,8 @@ pip install "pyautogen[teachable]"
2. Import agents
```python
from autogen import UserProxyAgent, config_list_from_json
from autogen.agentchat.contrib.teachable_agent import TeachableAgent
from autogen.agentchat.contrib.capabilities.teachability import Teachability
from autogen import ConversableAgent # As an example
```
3. Create llm_config
@ -56,71 +59,68 @@ llm_config={"config_list": config_list, "timeout": 120}
4. Create the agents
```python
teachable_agent = TeachableAgent(
name="teachableagent",
llm_config=llm_config,
teach_config={
"reset_db": False, # Use True to force-reset the memo DB, and False to use an existing DB.
"path_to_db_dir": "./tmp/interactive/teachable_agent_db" # Can be any path.
}
# Start by instantiating any agent that inherits from ConversableAgent, which we use directly here for simplicity.
teachable_agent = ConversableAgent(
name="teachable_agent", # The name can be anything.
llm_config=llm_config
)
# Instantiate a Teachability object. Its parameters are all optional.
teachability = Teachability(
reset_db=False, # Use True to force-reset the memo DB, and False to use an existing DB.
path_to_db_dir="./tmp/interactive/teachability_db" # Can be any path, but teachable agents in a group chat require unique paths.
)
# Now add teachability to the agent.
teachability.add_to_agent(teachable_agent)
# For this test, create a user proxy agent as usual.
user = UserProxyAgent("user", human_input_mode="ALWAYS")
```
5. Chat with `TeachableAgent`
5. Chat with the teachable agent
```python
# This function will return once the user types 'exit'.
teachable_agent.initiate_chat(user, message="Hi, I'm a teachable user assistant! What's on your mind?")
```
6. Update the database on disk
```python
# Before closing the app, let the teachable agent store things that should be learned from this chat.
teachable_agent.learn_from_user_feedback()
teachable_agent.close_db()
```
## Example 1 - Learning user info
A user can teach the agent facts about themselves.
(Note that due to their finetuning, LLMs can be reluctant to admit that they know personal information.)
```
CLEARING MEMORY
teachableagent (to user):
Loading previous memory (if any) from disk.
teachable_agent (to user):
Hi, I'm a teachable user assistant! What's on your mind?
Greetings, I'm a teachable user assistant! What's on your mind today?
--------------------------------------------------------------------------------
Provide feedback to teachableagent. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: My name is Ricky
user (to teachableagent):
Provide feedback to teachable_agent. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: My name is Ricky
user (to teachable_agent):
My name is Ricky
--------------------------------------------------------------------------------
teachableagent (to user):
Hello, Ricky! It's nice to meet you. What can I help you with today?
>>>>>>>> USING AUTO REPLY...
teachable_agent (to user):
Hello Ricky! How can I assist you today?
--------------------------------------------------------------------------------
Provide feedback to teachableagent. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: exit
REVIEWING CHAT FOR USER TEACHINGS TO REMEMBER
SAVING MEMORY TO DISK
Location = ./tmp/interactive/teachable_agent_db\uid_text_dict.pkl
Provide feedback to teachable_agent. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: exit
```
In a later conversation, the user can check whether the teachable agent remembers their name. (For readability, the user prompts and some logged notices are not repeated below.)
```
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
what's my name?
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
Hi Ricky! How can I help you today?
@ -133,44 +133,44 @@ Hi Ricky! How can I help you today?
A user can teach the agent more complex, related facts.
```
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
What is the Vicuna model?
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
I'm sorry, but I don't have information about the "Vicuna model" in my current database. Could you please provide some context, description, or details about this model so I can better assist you?
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
Vicuna is a 13B-parameter language model released by Meta.
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
Thank you for providing the information. The Vicuna model is a large-scale language model developed by Meta. It consists of 13 billion parameters, which is what makes it able to understand and generate human-like text across various topics. Parameters in a neural network are the weights and biases assigned to each connection between the neurons, and having more parameters enables the model to learn and capture more complex patterns and relationships in data. As an AI language model, Vicuna can be utilized in various language-related tasks, such as text completion, translation, summarization, and more.
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
What is the Orca model?
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
The Orca model is a yet-to-be-released language model by Meta, referring to the Research Preview API release of Meta's advanced language model. While I don't have specific details about the Orca model at this time, it is expected that the model would have an even larger number of parameters to further enhance its capabilities compared to its predecessor, Vicuna.
As AI language models evolve, they generally become more efficient at understanding and generating human-like text. This makes them more effective in completing a wide range of tasks, including text completion, translation, summarization, and Q&A, as well as providing more accurate and coherent results in various applications.
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
Orca is a 13B-parameter language model released by Microsoft. It outperforms Vicuna on most tasks.
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
Thank you for the clarification. Orca is a 13-billion-parameter language model developed by Microsoft. It is designed to outperform Vicuna, the language model released by Meta. With a large number of parameters, Orca is capable of understanding and generating human-like text across diverse subjects, making it highly useful in various applications.
@ -181,12 +181,12 @@ Then in a later chat the teachable agent can answer questions about the facts it
(Remember to first close the previous chat by typing 'exit'.)
```
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
How does the Vicuna model compare to the Orca model?
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
The Vicuna and Orca models are both 13B-parameter language models released by Meta and Microsoft, respectively. In terms of performance, Orca has been reported to outperform Vicuna on most tasks. However, without detailed information about specific tasks and benchmarks, it is difficult to provide a more comprehensive comparison. Generally speaking, both models are advanced language models that aim to provide high-quality natural language understanding and generation, but Orca appears to have an edge in terms of overall performance.
@ -202,7 +202,7 @@ Be aware that a message like the next one cannot be entered as a single message
Such messages can be entered in a Jupyter notebook, or through some UI layer like that of ChatGPT.
```
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
Please summarize this abstract.
@ -211,7 +211,7 @@ Qingyun Wu, Gagan Bansal, Jieyu Zhang, Yiran Wu, Beibin Li, Erkang Zhu, Li Jiang
AutoGen is an open-source framework that allows developers to build LLM applications via multiple agents that can converse with each other to accomplish tasks. AutoGen agents are customizable, conversable, and can operate in various modes that employ combinations of LLMs, human inputs, and tools. Using AutoGen, developers can also flexibly define agent interaction behaviors. Both natural language and computer code can be used to program flexible conversation patterns for different applications. AutoGen serves as a generic infrastructure to build diverse applications of various complexities and LLM capacities. Empirical studies demonstrate the effectiveness of the framework in many example applications, with domains ranging from mathematics, coding, question answering, operations research, online decision-making, entertainment, etc.
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
AutoGen is an open-source framework designed to enable developers to create LLM applications with multiple customizable agents that can converse with each other to complete tasks. These agents can operate using different combinations of LLMs, human inputs, and tools, allowing developers to define agent interaction behaviors flexibly. AutoGen supports programming flexible conversation patterns using both natural language and code, making it suitable for building diverse applications with varying complexities and LLM capacities. Its effectiveness has been demonstrated through empirical studies across various domains including mathematics, coding, operations research, decision-making, and entertainment.
@ -223,7 +223,7 @@ Qingyun Wu, Gagan Bansal, Jieyu Zhang, Yiran Wu, Beibin Li, Erkang Zhu, Li Jiang
AutoGen is an open-source framework that allows developers to build LLM applications via multiple agents that can converse with each other to accomplish tasks. AutoGen agents are customizable, conversable, and can operate in various modes that employ combinations of LLMs, human inputs, and tools. Using AutoGen, developers can also flexibly define agent interaction behaviors. Both natural language and computer code can be used to program flexible conversation patterns for different applications. AutoGen serves as a generic infrastructure to build diverse applications of various complexities and LLM capacities. Empirical studies demonstrate the effectiveness of the framework in many example applications, with domains ranging from mathematics, coding, question answering, operations research, online decision-making, entertainment, etc.
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
- Title: AutoGen: Enabling Next-Gen LLM Applications via Multi-Agent Conversation
- Innovation: Open-source framework for creating customizable LLM applications through agent conversations, supporting various modes and interaction behaviors.
@ -235,7 +235,7 @@ teachableagent (to user):
Then in later chats the teacher doesn't need to reiterate their detailed preferences.
```
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
Please summarize this abstract.
@ -244,7 +244,7 @@ Sébastien Bubeck, Varun Chandrasekaran, Ronen Eldan, Johannes Gehrke, Eric Horv
Artificial intelligence (AI) researchers have been developing and refining large language models (LLMs) that exhibit remarkable capabilities across a variety of domains and tasks, challenging our understanding of learning and cognition. The latest model developed by OpenAI, GPT-4, was trained using an unprecedented scale of compute and data. In this paper, we report on our investigation of an early version of GPT-4, when it was still in active development by OpenAI. We contend that (this early version of) GPT-4 is part of a new cohort of LLMs (along with ChatGPT and Google's PaLM for example) that exhibit more general intelligence than previous AI models. We discuss the rising capabilities and implications of these models. We demonstrate that, beyond its mastery of language, GPT-4 can solve novel and difficult tasks that span mathematics, coding, vision, medicine, law, psychology and more, without needing any special prompting. Moreover, in all of these tasks, GPT-4's performance is strikingly close to human-level performance, and often vastly surpasses prior models such as ChatGPT. Given the breadth and depth of GPT-4's capabilities, we believe that it could reasonably be viewed as an early (yet still incomplete) version of an artificial general intelligence (AGI) system. In our exploration of GPT-4, we put special emphasis on discovering its limitations, and we discuss the challenges ahead for advancing towards deeper and more comprehensive versions of AGI, including the possible need for pursuing a new paradigm that moves beyond next-word prediction. We conclude with reflections on societal influences of the recent technological leap and future research directions.
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
- Title: Sparks of Artificial General Intelligence: Early experiments with GPT-4
- Innovation: GPT-4, an LLM with remarkable capabilities, demonstrates human-level performance across various domains, like math, coding, vision, medicine, law, and psychology.
@ -261,7 +261,7 @@ Users can extend the teachable agent's capabilities by teaching it new skills fo
The [Sparks of AGI](https://arxiv.org/abs/2303.12712) paper evaluated GPT-4 on math problems like the following, which it could only solve 32% of the time. We first show a failure case, then teach the agent a strategy which lifts GPT-4's success rate above 95%.
```
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
Consider the identity:
9 * 4 + 6 * 6 = 72
@ -286,7 +286,7 @@ Step 4: Write down the solution.
The solution is: 9 * 4 + 6 * 7
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
Consider the identity:
9 * 4 + 6 * 6 = 72
@ -306,7 +306,7 @@ Here's some advice that may help:
Finally, write down your solution as: "The solution is: A * B + C * D".
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
Step 1: Identify the original number (E) and the final number (F) on the right-hand side.
E = 72
@ -348,7 +348,7 @@ The solution is: 9 * 7 + 6 * 6
In a later chat the user doesn't need to repeat the detailed advice.
```
--------------------------------------------------------------------------------
user (to teachableagent):
user (to teachable_agent):
Consider the identity:
8 * 3 + 7 * 9 = 87
@ -356,7 +356,7 @@ Can you modify exactly one integer (and not more than that!) on the left hand si
-Let's think step-by-step, write down a plan, and then write down your solution as: "The solution is: A * B + C * D".
--------------------------------------------------------------------------------
teachableagent (to user):
teachable_agent (to user):
Let's follow the advice to modify one integer on the left hand side of the equation to make the right hand side become 59:
@ -377,15 +377,11 @@ The solution is: 8 * 3 + 7 * 5
## Planned improvements
- Instructions for making any AutoGen agent user-teachable.
- Examples of how to include `TeachableAgent` in group chats.
- Expansions of AutoGen's current coding-testing strengths.
- Teachability enhancements:
- Understanding user instructions distributed over multiple turns.
- Learning from the agent's own experience, to reduce dependence on explicit user instructions.
- Learning skills built on top of previously learned skills.
- Understanding user instructions distributed over multiple turns.
- Learning from the agent's own experience, to reduce dependence on explicit user teachings.
- Learning skills built on top of previously learned skills.
## Conclusion
`TeachableAgent` is still under active research and development. For any problems you find or improvements you have in mind, please join our discussions in this repo and on our [Discord channel](https://discord.gg/pAbnFJrkgZ). We look forward to seeing how you and the rest of the community can use and improve `TeachableAgent` and the other agents in AutoGen!
`Teachability` is still under active research and development. For any problems you find or improvements you have in mind, please join our discussions in this repo and on our [Discord channel](https://discord.gg/pAbnFJrkgZ). We look forward to seeing how you and the rest of the community can use and improve teachable agents in AutoGen!

View File

@ -47,6 +47,7 @@ Links to notebook examples:
1. **Agent Teaching and Learning**
- Teach Agents New Skills & Reuse via Automated Chat - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teaching.ipynb)
- Teach Agents New Facts, User Preferences and Skills Beyond Coding - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachability.ipynb)
- Teach OpenAI Assistants Through GPTAssistantAgent - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachable_oai_assistants.ipynb)
- Agent Optimizer: Train Agents in an Agentic Way - [View Notebook](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_agentoptimizer.ipynb)
1. **Multi-Agent Chat with OpenAI Assistants in the loop**

View File

@ -138,14 +138,14 @@ Example notebooks:
[Automated Code Generation and Question Answering with Qdrant based Retrieval Augmented Agents](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_qdrant_RetrieveChat.ipynb)
- #### TeachableAgent
- #### Teachability
To use TeachableAgent, please install AutoGen with the [teachable] option.
To use Teachability, please install AutoGen with the [teachable] option.
```bash
pip install "pyautogen[teachable]"
```
Example notebook: [Chatting with TeachableAgent](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachability.ipynb)
Example notebook: [Chatting with a teachable agent](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_teachability.ipynb)
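
After installing, making an agent teachable takes only a few lines. The snippet below is a condensed sketch of the usage shown in the notebook; the database path and agent name are placeholders.

```python
from autogen import ConversableAgent, config_list_from_json
from autogen.agentchat.contrib.capabilities.teachability import Teachability

# Load LLM endpoints as usual, then build an llm_config for the agent.
config_list = config_list_from_json(env_or_file="OAI_CONFIG_LIST")
llm_config = {"config_list": config_list, "timeout": 120}

# Any agent inheriting from ConversableAgent can be made teachable.
teachable_agent = ConversableAgent("teachable_agent", llm_config=llm_config)
teachability = Teachability(path_to_db_dir="./tmp/teachability_db")  # Memos persist here across chats.
teachability.add_to_agent(teachable_agent)
```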