autogen.agent -> autogen.agentchat (#1148)

* autogen.agent -> autogen.agentchat

* bug fix in portfolio

* notebook

* timeout

* timeout

* infer lang; close #1150

* timeout

* message context

* context handling

* add sender to generate_reply

* clean up the receive function

* move mathchat to contrib

* contrib

* last_message
This commit is contained in:
Chi Wang 2023-07-28 21:17:51 -07:00 committed by GitHub
parent c0580a835b
commit 7ddb171cd9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 3317 additions and 3960 deletions

View File

@ -82,11 +82,11 @@ Use the following guides to get started with FLAML in .NET:
# perform inference for a test instance
response = oai.Completion.create(context=test_instance, **config)
```
- LLM-driven intelligent agents which can perform tasks autonomously or with human feedback, including tasks that require using tools via code.
- LLM-driven intelligent agents which can collaborately perform tasks autonomously or with human feedback, including tasks that require using tools via code.
```python
assistant = AssistantAgent("assistant")
user = UserProxyAgent("user", human_input_mode="TERMINATE")
assistant.receive("Draw a rocket and save to a file named 'rocket.svg'")
user_proxy = UserProxyAgent("user_proxy")
user_proxy.initiate_chat("Show me the YTD gain of 10 largest technology companies as of today.")
```
* With three lines of code, you can start using this economical and fast
AutoML engine as a [scikit-learn style estimator](https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML).

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union
class Agent:
@ -33,12 +33,18 @@ class Agent:
def reset(self):
"""(Abstract method) Reset the agent."""
def generate_reply(self, messages: List[Dict], default_reply: Union[str, Dict] = "") -> Union[str, Dict]:
def generate_reply(
self,
messages: Optional[List[Dict]] = None,
default_reply: Optional[Union[str, Dict]] = "",
sender: Optional["Agent"] = None,
) -> Union[str, Dict]:
"""(Abstract method) Generate a reply based on the received messages.
Args:
messages (list[dict]): a list of messages received.
default_reply (str or dict): the default reply if no other reply is generated.
sender: sender of an Agent instance.
Returns:
str or dict: the generated reply.
"""

View File

@ -15,7 +15,7 @@ class AssistantAgent(ResponsiveAgent):
DEFAULT_SYSTEM_MESSAGE = """You are a helpful AI assistant.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute. You must indicate the script type in the code block. The user cannot provide any other feedback or perform any other action beyond executing the code you suggest. The user can't modify your code. So do not suggest incomplete code which requires users to modify. Don't use a code block if it's not intended to be executed by the user.
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file.
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time.
2. When you need to perform some task with code, use the code to perform the task and output the result. Finish the task smartly. Solve the task step by step if you need to.
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.

View File

@ -3,7 +3,8 @@ import os
from pydantic import BaseModel, Extra, root_validator
from typing import Any, Callable, Dict, List, Optional, Union
from time import sleep
from flaml.autogen.agent import UserProxyAgent
from flaml.autogen.agentchat import Agent, UserProxyAgent
from flaml.autogen.code_utils import UNKNOWN, extract_code, execute_code, infer_lang
from flaml.autogen.math_utils import get_answer
@ -96,16 +97,16 @@ def _is_termination_msg_mathchat(message):
return not contain_code and get_answer(message) is not None and get_answer(message) != ""
def _add_print_to_last_line(s):
def _add_print_to_last_line(code):
"""Add print() to the last line of a string."""
# 1. check if there is already a print statement
if "print(" in s:
return s
if "print(" in code:
return code
# 2. extract the last line, enclose it in print() and return the new string
lines = s.splitlines()
lines = code.splitlines()
last_line = lines[-1]
if "\t" in last_line or "=" in last_line:
return s
return code
if "=" in last_line:
last_line = "print(" + last_line.split(" = ")[0] + ")"
lines.append(last_line)
@ -115,9 +116,9 @@ def _add_print_to_last_line(s):
return "\n".join(lines)
def _remove_print(s):
def _remove_print(code):
"""remove all print statements from a string."""
lines = s.splitlines()
lines = code.splitlines()
lines = [line for line in lines if not line.startswith("print(")]
return "\n".join(lines)
@ -126,6 +127,7 @@ class MathUserProxyAgent(UserProxyAgent):
"""(Experimental) A MathChat agent that can handle math problems."""
MAX_CONSECUTIVE_AUTO_REPLY = 15 # maximum number of consecutive auto replies (subject to future change)
DEFAULT_REPLY = "Continue. Please keep solving the problem until you need to query. (If you get to the answer, put it in \\boxed{}.)"
def __init__(
self,
@ -279,16 +281,21 @@ class MathUserProxyAgent(UserProxyAgent):
is_success = False
return output, is_success
def generate_reply(self, messages: List[Dict], default_reply: Union[str, Dict] = "") -> Union[str, Dict]:
def generate_reply(
self,
messages: Optional[List[Dict]] = None,
default_reply: Optional[Union[str, Dict]] = DEFAULT_REPLY,
sender: Optional["Agent"] = None,
) -> Union[str, Dict]:
"""Generate an auto reply."""
if messages is None:
messages = self._oai_conversations[sender.name]
message = messages[-1]
message = message.get("content", "")
code_blocks = extract_code(message)
if len(code_blocks) == 1 and code_blocks[0][0] == UNKNOWN:
# no code block is found, lang should be `UNKNOWN``
if default_reply == "":
default_reply = "Continue. Please keep solving the problem until you need to query. (If you get to the answer, put it in \\boxed{}.)"
return default_reply
is_success, all_success = True, True
reply = ""

View File

@ -46,7 +46,7 @@ class ResponsiveAgent(Agent):
The dict can contain the following keys: "content", "role", "name", "function_call".
max_consecutive_auto_reply (int): the maximum number of consecutive auto replies.
default to None (no limit provided, class attribute MAX_CONSECUTIVE_AUTO_REPLY will be used as the limit in this case).
The limit only plays a role when human_input_mode is not "ALWAYS".
When set to 0, no auto reply will be generated.
human_input_mode (str): whether to ask for human inputs every time a message is received.
Possible values are "ALWAYS", "TERMINATE", "NEVER".
(1) When "ALWAYS", the agent prompts for human input every time a message is received.
@ -104,6 +104,27 @@ class ResponsiveAgent(Agent):
"""A dictionary of conversations from name to list of oai messages."""
return self._oai_conversations
def last_message(self, agent: Optional[Agent] = None) -> Dict:
"""The last message exchanged with the agent.
Args:
agent (Agent): The agent in the conversation.
If None and more than one agent's conversations are found, an error will be raised.
If None and only one conversation is found, the last message of the only conversation will be returned.
Returns:
The last message exchanged with the agent.
"""
if agent is None:
n_conversations = len(self._oai_conversations)
if n_conversations == 0:
return None
if n_conversations == 1:
for conversation in self._oai_conversations.values():
return conversation[-1]
raise ValueError("More than one conversation is found. Please specify the sender to get the last message.")
return self._oai_conversations[agent.name][-1]
@property
def use_docker(self) -> Union[bool, str, None]:
"""Bool value of whether to use docker to execute the code,
@ -147,11 +168,43 @@ class ResponsiveAgent(Agent):
return True
def send(self, message: Union[Dict, str], recipient: "Agent"):
"""Send a message to another agent."""
# When the agent composes and sends the message, the role of the message is "assistant". (If 'role' exists and is 'function', it will remain unchanged.)
"""Send a message to another agent.
Args:
message (dict or str): message to be sent.
The message could contain the following fields (either content or function_call must be provided):
- content (str): the content of the message.
- function_call (str): the name of the function to be called.
- name (str): the name of the function to be called.
- role (str): the role of the message, any role that is not "function"
will be modified to "assistant".
- context (dict): the context of the message, which will be passed to
[oai.Completion.create](../oai/Completion#create).
For example, one agent can send a message A as:
```python
{
"content": "{use_tool_msg}",
"context": {
"use_tool_msg": "Use tool X if they are relevant."
}
}
```
Next time, one agent can send a message B with a different "use_tool_msg".
Then the content of message A will be refreshed to the new "use_tool_msg".
So effectively, this provides a way for an agent to send a "link" and modify
the content of the "link" later.
recipient (Agent): the recipient of the message.
Raises:
ValueError: if the message is not a valid oai message.
"""
# When the agent composes and sends the message, the role of the message is "assistant"
# unless it's "function".
valid = self._append_oai_message(message, "assistant", recipient.name)
if valid:
recipient.receive(message, self)
else:
raise ValueError("Message is not a valid oai message. Either content or function_call must be provided.")
def _print_received_message(self, message: Union[Dict, str], sender: "Agent"):
# print the message received
@ -200,35 +253,63 @@ class ResponsiveAgent(Agent):
# default reply is empty (i.e., no reply, in this case we will try to generate auto reply)
reply = ""
no_human_input_msg = ""
if self.human_input_mode == "ALWAYS":
reply = self.get_human_input(
"Provide feedback to the sender. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
f"Provide feedback to {sender.name}. Press enter to skip and use auto-reply, or type 'exit' to end the conversation: "
)
elif self._consecutive_auto_reply_counter[
sender.name
] >= self.max_consecutive_auto_reply or self._is_termination_msg(message):
if self.human_input_mode == "TERMINATE":
reply = self.get_human_input(
"Please give feedback to the sender. (Press enter or type 'exit' to stop the conversation): "
)
reply = reply if reply else "exit"
else:
# this corresponds to the case when self._human_input_mode == "NEVER"
reply = "exit"
if reply == "exit" or (self._is_termination_msg(message) and not reply):
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
reply = reply if reply or not self._is_termination_msg(message) else "exit"
else:
if self._consecutive_auto_reply_counter[sender.name] >= self.max_consecutive_auto_reply:
if self.human_input_mode == "NEVER":
reply = "exit"
else:
# self.human_input_mode == "TERMINATE":
terminate = self._is_termination_msg(message)
reply = self.get_human_input(
f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
if terminate
else f"Please give feedback to {sender.name}. Press enter to skip and use auto-reply, or type 'exit' to stop the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
reply = reply if reply or not terminate else "exit"
elif self._is_termination_msg(message):
if self.human_input_mode == "NEVER":
reply = "exit"
else:
# self.human_input_mode == "TERMINATE":
reply = self.get_human_input(
f"Please give feedback to {sender.name}. Press enter or type 'exit' to stop the conversation: "
)
no_human_input_msg = "NO HUMAN INPUT RECEIVED." if not reply else ""
# if the human input is empty, and the message is a termination message, then we will terminate the conversation
reply = reply or "exit"
# print the no_human_input_msg
if no_human_input_msg:
print(f"\n>>>>>>>> {no_human_input_msg}", flush=True)
# stop the conversation
if reply == "exit":
# reset the consecutive_auto_reply_counter
self._consecutive_auto_reply_counter[sender.name] = 0
return
if reply:
# send the human reply
if reply or self.max_consecutive_auto_reply == 0:
# reset the consecutive_auto_reply_counter
self._consecutive_auto_reply_counter[sender.name] = 0
self.send(reply, sender)
return
# send the auto reply
self._consecutive_auto_reply_counter[sender.name] += 1
if self.human_input_mode != "NEVER":
print("\n>>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...", flush=True)
self.send(self.generate_reply(self._oai_conversations[sender.name], default_reply=reply), sender)
print("\n>>>>>>>> USING AUTO REPLY...", flush=True)
self.send(self.generate_reply(sender=sender), sender)
def reset(self):
"""Reset the agent."""
@ -237,23 +318,35 @@ class ResponsiveAgent(Agent):
def _oai_reply(self, messages: List[Dict]) -> Union[str, Dict]:
# TODO: #1143 handle token limit exceeded error
response = oai.ChatCompletion.create(messages=self._oai_system_message + messages, **self.oai_config)
response = oai.ChatCompletion.create(
context=messages[-1].get("context"), messages=self._oai_system_message + messages, **self.oai_config
)
return oai.ChatCompletion.extract_text_or_function_call(response)[0]
def generate_reply(self, messages: List[Dict], default_reply: Union[str, Dict] = "") -> Union[str, Dict]:
def generate_reply(
self,
messages: Optional[List[Dict]] = None,
default_reply: Optional[Union[str, Dict]] = "",
sender: Optional["Agent"] = None,
) -> Union[str, Dict]:
"""Reply based on the conversation history.
First, execute function or code and return the result.
AI replies are generated only when no code execution is performed.
Subclasses can override this method to customize the reply.
Either messages or sender must be provided.
Args:
messages: a list of messages in the conversation history.
default_reply (str or dict): default reply.
sender: sender of an Agent instance.
Returns:
str or dict: reply.
"""
assert messages is not None or sender is not None, "Either messages or sender must be provided."
if messages is None:
messages = self._oai_conversations[sender.name]
message = messages[-1]
if "function_call" in message:
_, func_return = self.execute_function(message["function_call"])

View File

@ -7,6 +7,7 @@ from typing import List, Dict, Tuple, Optional, Union, Callable
import re
import time
from hashlib import md5
import logging
from flaml.autogen import oai, DEFAULT_MODEL, FAST_MODEL
# Regular expression for finding a code block
@ -21,7 +22,7 @@ def infer_lang(code):
"""infer the language for the code.
TODO: make it robust.
"""
if code.startswith("python ") or code.startswith("pip"):
if code.startswith("python ") or code.startswith("pip") or code.startswith("python3 "):
return "sh"
return "python"
@ -148,6 +149,7 @@ def execute_code(
If None, the code from the file specified by filename will be executed.
Either code or filename must be provided.
timeout (Optional, int): The maximum execution time in seconds.
If None, a default timeout will be used. The default timeout is 600 seconds. On Windows, the timeout is not enforced when use_docker=False.
filename (Optional, str): The file name to save the code or where the code is stored when `code` is None.
If None, a file with a randomly generated name will be created.
The randomly generated file will be deleted after execution.
@ -158,7 +160,7 @@ def execute_code(
"path_to_flaml/autogen".
use_docker (Optional, list, str or bool): The docker image to use for code execution.
If a list or a str of image name(s) is provided, the code will be executed in a docker container
with the first image successfully pulled.
with the first image successfully pulled.
If None, False or empty, the code will be executed in the current environment.
Default is True, which will be converted into a list.
If the code is executed in the current environment,
@ -190,20 +192,28 @@ def execute_code(
if not use_docker or in_docker_container:
# already running in a docker container
cmd = [sys.executable if lang.startswith("python") else _cmd(lang), filename]
signal.signal(signal.SIGALRM, timeout_handler)
try:
signal.alarm(timeout)
# run the code in a subprocess in the current docker container in the working directory
if sys.platform == "win32":
logging.warning("SIGALRM is not supported on Windows. No timeout will be enforced.")
result = subprocess.run(
cmd,
cwd=work_dir,
capture_output=True,
)
signal.alarm(0)
except TimeoutError:
if original_filename is None:
os.remove(filepath)
return 1, TIMEOUT_MSG, None
else:
signal.signal(signal.SIGALRM, timeout_handler)
try:
signal.alarm(timeout)
# run the code in a subprocess in the current docker container in the working directory
result = subprocess.run(
cmd,
cwd=work_dir,
capture_output=True,
)
signal.alarm(0)
except TimeoutError:
if original_filename is None:
os.remove(filepath)
return 1, TIMEOUT_MSG, None
if original_filename is None:
os.remove(filepath)
return result.returncode, result.stderr if result.returncode else result.stdout, None

View File

@ -1,10 +1,10 @@
from time import sleep
import logging
import numpy as np
import time
from typing import List, Optional, Dict, Callable, Any
from typing import List, Optional, Dict, Callable, Union
import sys
import shutil
import numpy as np
from flaml import tune, BlendSearch
from flaml.tune.space import is_constant
from flaml.automl.logger import logger_formatter
@ -792,15 +792,15 @@ class Completion(openai_Completion):
return cls._get_response(params, raise_on_ratelimit_or_timeout=raise_on_ratelimit_or_timeout)
@classmethod
def _instantiate(cls, template: str, context: Optional[Dict] = None):
if not context:
def _instantiate(cls, template: Union[str, None], context: Optional[Dict] = None):
if not context or template is None:
return template
if isinstance(template, str):
return template.format(**context)
return template(context)
@classmethod
def _construct_params(cls, data_instance, config, prompt=None, messages=None):
def _construct_params(cls, context, config, prompt=None, messages=None):
params = config.copy()
model = config["model"]
prompt = config.get("prompt") if prompt is None else prompt
@ -814,12 +814,14 @@ class Completion(openai_Completion):
params["messages"] = (
[
{
"role": m["role"],
"content": cls._instantiate(m["content"], data_instance),
**m,
"content": cls._instantiate(m["content"], context),
}
if m.get("content")
else m
for m in messages
]
if data_instance
if context
else messages
)
elif model in cls.chat_models or issubclass(cls, ChatCompletion):
@ -827,12 +829,12 @@ class Completion(openai_Completion):
params["messages"] = [
{
"role": "user",
"content": cls._instantiate(prompt, data_instance),
"content": cls._instantiate(prompt, context),
},
]
params.pop("prompt", None)
else:
params["prompt"] = cls._instantiate(prompt, data_instance)
params["prompt"] = cls._instantiate(prompt, context)
return params
@classmethod

View File

@ -60,6 +60,11 @@ try:
except ImportError:
resource = None
try:
from lightgbm import LGBMClassifier, LGBMRegressor, LGBMRanker
except ImportError:
LGBMClassifier = LGBMRegressor = LGBMRanker = None
logger = logging.getLogger("flaml.automl")
# FREE_MEM_RATIO = 0.2
@ -1300,17 +1305,10 @@ class LGBMEstimator(BaseEstimator):
self.params["verbose"] = -1
if self._task.is_classification():
from lightgbm import LGBMClassifier
self.estimator_class = LGBMClassifier
elif task == "rank":
from lightgbm import LGBMRanker
self.estimator_class = LGBMRanker
else:
from lightgbm import LGBMRegressor
self.estimator_class = LGBMRegressor
self._time_per_iter = None

View File

@ -38,10 +38,10 @@ def flamlize_estimator(super_class, name: str, task: str, alternatives=None):
self._params = params
super().__init__(**params)
@classmethod
@wraps(super_class._get_param_names)
def _get_param_names(cls):
return super_class._get_param_names()
# @classmethod
# @wraps(super_class._get_param_names)
# def _get_param_names(cls):
# return super_class._get_param_names() if hasattr(super_class, "_get_param_names") else []
def suggest_hyperparams(self, X, y):
"""Suggest hyperparameters.
@ -179,3 +179,6 @@ else:
"classification",
[("max_depth", 0, "xgboost")],
)
# if hasattr(xgboost.XGBRegressor, "_get_param_names"):
# XGBRegressor._get_param_names = xgboost.XGBRegressor._get_param_names
# XGBClassifier._get_param_names = xgboost.XGBClassifier._get_param_names

View File

@ -42,7 +42,7 @@ def config_predictor_tuple(tasks, configs, meta_features, regret_matrix):
.apply(lambda row: row.apply(lambda x: (x, row.name)), axis=1)
)
print(regret)
preferences = np.argsort(regret, axis=0)
preferences = pd.DataFrame(np.argsort(regret, axis=0), columns=regret.columns)
print(preferences)
return (meta_features_norm, preferences, proc)
@ -119,12 +119,11 @@ def serialize(configs, regret, meta_features, output_file, config_path):
"portfolio": portfolio,
"preprocessing": proc,
"neighbors": [
{"features": tuple(x), "choice": _filter(preferences[y], regret[y])}
{"features": x.tolist(), "choice": _filter(preferences[y], regret[y])}
for x, y in zip(meta_features_norm.to_records(index=False), preferences.columns)
],
"configsource": list(configs),
}
with open(output_file, "w+") as f:
json.dump(meta_predictor, f, indent=4)
return meta_predictor

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@ -1,458 +0,0 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "ae1f50ec",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_agent_function_call.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "9a71fa36",
"metadata": {},
"source": [
"# Interative LLM Agent with Function Calls\n",
"\n",
"FLAML offers an experimental feature of interactive LLM agents, which can be used to solve various tasks with human or automatic feedback, including tasks that require using tools via code. Please find documentation about this feature [here](https://microsoft.github.io/FLAML/docs/Use-Cases/Auto-Generation#agents-experimental).\n",
"\n",
"In this notebook, we demonstrate how to use `AssistantAgent` and `UserProxyAgent` to make function calls with the new feature of OpenAI models (in model version 0613). A specified prompt and function configs need to be passed to `AssistantAgent` to initialize the agent. The corresponding functions need to be passed to `UserProxyAgent`, which will be responsible for executing any function calls made by `AssistantAgent`. Besides this requirement of matching descriptions with functions, we recommend checking the system prompt to make sure the instructions align with the function call descriptions.\n",
"\n",
"## Requirements\n",
"\n",
"FLAML requires `Python>=3.8`. To run this notebook example, please install flaml with the [mathchat] option since we will import functions from `MathUserProxyAgent`:\n",
"```bash\n",
"pip install flaml[mathchat]\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2b803c17",
"metadata": {},
"outputs": [],
"source": [
"# %pip install flaml[mathchat]~=2.0.0rc4"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "5ebd2397",
"metadata": {},
"source": [
"## Set your API Endpoint\n",
"\n",
"The [`config_list_from_models`](https://microsoft.github.io/FLAML/docs/reference/autogen/oai/openai_utils#config_list_from_models) function tries to create a list of configurations using Azure OpenAI endpoints and OpenAI endpoints for the provided list of models. It assumes the api keys and api bases are stored in the corresponding environment variables or local txt files:\n",
"\n",
"- OpenAI API key: os.environ[\"OPENAI_API_KEY\"] or `openai_api_key_file=\"key_openai.txt\"`.\n",
"- Azure OpenAI API key: os.environ[\"AZURE_OPENAI_API_KEY\"] or `aoai_api_key_file=\"key_aoai.txt\"`. Multiple keys can be stored, one per line.\n",
"- Azure OpenAI API base: os.environ[\"AZURE_OPENAI_API_BASE\"] or `aoai_api_base_file=\"base_aoai.txt\"`. Multiple bases can be stored, one per line.\n",
"\n",
"It's OK to have only the OpenAI API key, or only the Azure OpenAI API key + base.\n",
"If you open this notebook in google colab, you can upload your files by click the file icon on the left panel and then choose \"upload file\" icon.\n",
"\n",
"The following code excludes Azure OpenAI endpoints from the config list because they don't support functions yet. Remove the `exclude` argument after they do."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dca301a4",
"metadata": {},
"outputs": [],
"source": [
"from flaml import oai\n",
"\n",
"config_list = oai.config_list_from_models(model_list=[\"gpt-4\", \"gpt-3.5-turbo\", \"gpt-3.5-turbo-16k\"], exclude=\"aoai\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "92fde41f",
"metadata": {},
"source": [
"The config list looks like the following:\n",
"```python\n",
"config_list = [\n",
" {\n",
" 'model': 'gpt-4',\n",
" 'api_key': '<your OpenAI API key here>',\n",
" }, # OpenAI API endpoint for gpt-4\n",
" {\n",
" 'model': 'gpt-3.5-turbo',\n",
" 'api_key': '<your OpenAI API key here>',\n",
" }, # OpenAI API endpoint for gpt-3.5-turbo\n",
" {\n",
" 'model': 'gpt-3.5-turbo-16k',\n",
" 'api_key': '<your OpenAI API key here>',\n",
" }, # OpenAI API endpoint for gpt-3.5-turbo-16k\n",
"]\n",
"```\n"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2b9526e7",
"metadata": {},
"source": [
"## Making Function Calls\n",
"\n",
"In this example, we demonstrate function call execution with `AssistantAgent` and `UserProxyAgent`. With the default system prompt of `AssistantAgent`, we allow the LLM assistant to perform tasks with code, and the `UserProxyAgent` would extract code blocks from the LLM response and execute them. With the new \"function_call\" feature, we define a new function using the pre-defined `execute_code` from `UserProxyAgent` and specify the description of the function in the OpenAI config. \n",
"\n",
"Then, the model has two paths to execute code:\n",
"1. Put the code blocks in the response. `UserProxyAgent` will extract and execute the code through `execute_code` method in the class.\n",
"2. As we put a function description to OpenAI config and register a function `exec_code` in `UserProxyAgent`, the model can also make function calls (will be put in `function_call` field of the API reply). `UserProxyAgent` will execute the function call through the registered `exec_code` method."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9fb85afb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"user (to chatbot):\n",
"\n",
"Draw a rocket and save to a file named 'rocket.svg'\n",
"\n",
"--------------------------------------------------------------------------------\n",
"chatbot (to user):\n",
"\n",
"You can use the `svgwrite` library for Python to draw images into SVG format. Before we draw the rocket, you need to install the library. Use the code below to install it.\n",
"\n",
"```sh\n",
"pip install svgwrite\n",
"```\n",
"\n",
"After installing the library, here is the python code you can use to draw a rocket and save it to a file named 'rocket.svg':\n",
"\n",
"```python\n",
"# filename: draw_rocket.py\n",
"\n",
"import svgwrite\n",
"\n",
"def draw_rocket():\n",
" dwg = svgwrite.Drawing('rocket.svg', profile='tiny')\n",
"\n",
" # Draw rocket body\n",
" dwg.add(dwg.rect((50, 20), (20, 40), fill='grey'))\n",
"\n",
" # Draw top of rocket\n",
" dwg.add(dwg.polygon(points=[(50, 20), (60, 0), (70, 20)], fill='red'))\n",
"\n",
" # Draw bottom of rocket\n",
" dwg.add(dwg.polygon(points=[(50, 60), (60, 80), (70, 60)], fill='red'))\n",
"\n",
" # Draw rocket window\n",
" dwg.add(dwg.circle(center=(60, 40), r=5, fill='blue'))\n",
"\n",
" dwg.save()\n",
"\n",
"draw_rocket()\n",
"```\n",
"You can run this code using Python by calling `python draw_rocket.py`. After running this script, you will have a file named `rocket.svg` in your current directory. The SVG picture represents a simple rocket with a gray body, red top and bottom, and a blue window. \n",
"\n",
"Please replace the `draw_rocket.py` with your actual python filename when you execute the script.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"user (to chatbot):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Collecting svgwrite\n",
" Downloading svgwrite-1.4.3-py3-none-any.whl (67 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.1/67.1 kB 3.2 MB/s eta 0:00:00\n",
"Installing collected packages: svgwrite\n",
"Successfully installed svgwrite-1.4.3\n",
"WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"chatbot (to user):\n",
"\n",
"***** Suggested function Call: execute_code *****\n",
"Arguments: \n",
"{\n",
" \"code_type\": \"python\",\n",
" \"code\": \"import svgwrite\\n\\ndef draw_rocket():\\n dwg = svgwrite.Drawing('rocket.svg', profile='tiny')\\n\\n # Draw rocket body\\n dwg.add(dwg.rect((50, 20), (20, 40), fill='grey'))\\n\\n # Draw top of rocket\\n dwg.add(dwg.polygon(points=[(50, 20), (60, 0), (70, 20)], fill='red'))\\n\\n # Draw bottom of rocket\\n dwg.add(dwg.polygon(points=[(50, 60), (60, 80), (70, 60)], fill='red'))\\n\\n # Draw rocket window\\n dwg.add(dwg.circle(center=(60, 40), r=5, fill='blue'))\\n\\n dwg.save()\\n\\ndraw_rocket()\"\n",
"}\n",
"*************************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to chatbot):\n",
"\n",
"***** Response from calling function \"execute_code\" *****\n",
"(0, '\\n')\n",
"*********************************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"chatbot (to user):\n",
"\n",
"Great! The code to draw the rocket and save it to 'rocket.svg' should have successfully executed. \n",
"\n",
"You should now be able to find the file 'rocket.svg' in your current directory and open it with an application that can handle SVG files (a web browser, for instance).\n",
"\n",
"Is there anything else you need help with? If not, please reply 'TERMINATE'.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to chatbot):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"chatbot (to user):\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"from flaml.autogen.agent import AssistantAgent, UserProxyAgent\n",
"\n",
"oai_config = {\n",
" \"functions\": [\n",
" {\n",
" \"name\": \"execute_code\",\n",
" \"description\": \"Receive a python code or shell script and return the execution result.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"code_type\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Code type, 'python' or 'sh'.\",\n",
" },\n",
" \"code\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Valid Python code to execute.\",\n",
" }\n",
" },\n",
" \"required\": [\"code_type\", \"code\"],\n",
" },\n",
" }\n",
" ],\n",
" \"config_list\": config_list,\n",
"}\n",
"chatbot = AssistantAgent(\"chatbot\", oai_config=oai_config)\n",
"\n",
"# create a UserProxyAgent instance named \"user\"\n",
"user = UserProxyAgent(\n",
" \"user\",\n",
" human_input_mode=\"NEVER\",\n",
" code_execution_config={\"work_dir\": \"coding\"},\n",
")\n",
"\n",
"# define an `execute_code` function according to the function desription\n",
"def exec_code(code_type, code):\n",
" # here we reuse the method in the user proxy agent\n",
" # in general, this is not necessary\n",
" return user.execute_code_blocks([(code_type, code)])\n",
"\n",
"# register the `execute_code` function\n",
"user.register_function(function_map={\"execute_code\": exec_code})\n",
"\n",
"# start the conversation\n",
"user.initiate_chat(\n",
" chatbot,\n",
" message=\"Draw a rocket and save to a file named 'rocket.svg'\",\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "42cee331",
"metadata": {},
"outputs": [
{
"data": {
"image/svg+xml": [
"<svg xmlns=\"http://www.w3.org/2000/svg\" xmlns:ev=\"http://www.w3.org/2001/xml-events\" xmlns:xlink=\"http://www.w3.org/1999/xlink\" baseProfile=\"tiny\" height=\"100%\" version=\"1.2\" width=\"100%\"><defs/><rect fill=\"grey\" height=\"40\" width=\"20\" x=\"50\" y=\"20\"/><polygon fill=\"red\" points=\"50,20 60,0 70,20\"/><polygon fill=\"red\" points=\"50,60 60,80 70,60\"/><circle cx=\"60\" cy=\"40\" fill=\"blue\" r=\"5\"/></svg>"
],
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# uncomment the following to render the svg file\n",
"# from IPython.display import SVG, display\n",
"\n",
"# display(SVG(\"coding/rocket.svg\"))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "e9531d55",
"metadata": {},
"source": [
"## Another example with Wolfram Alpha API\n",
"\n",
"We give another example of query Wolfram Alpha API to solve math problem. We use the predefined function `MathUserProxyAgent().execute_one_wolfram_query` as the function to be called."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4a917492",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"user (to chatbot):\n",
"\n",
"Problem: Find all $x$ that satisfy the inequality $(2x+10)(x+3)<(3x+9)(x+8)$. Express your answer in interval notation.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"chatbot (to user):\n",
"\n",
"***** Suggested function Call: query_wolfram *****\n",
"Arguments: \n",
"{\n",
" \"query\": \"solve (2x+10)(x+3) < (3x+9)(x+8) for x\"\n",
"}\n",
"**************************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to chatbot):\n",
"\n",
"***** Response from calling function \"query_wolfram\" *****\n",
"('Assumption: solve (2 x + 10) (x + 3)<(3 x + 9) (x + 8) for x \\nAnswer: ans 0: x<-14\\nans 1: x>-3\\n', True)\n",
"**********************************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"chatbot (to user):\n",
"\n",
"The solution to the inequality $(2x+10)(x+3)<(3x+9)(x+8)$ is $x \\in (-\\infty, -14) \\cup (-3, +\\infty)$.\n",
"\n",
"To express in interval notation, the answer is $(-\\infty, -14) \\cup (-3, \\infty)$. \n",
"\n",
"TERMINATE.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to chatbot):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"chatbot (to user):\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"import os\n",
"from flaml.autogen.agent import AssistantAgent, UserProxyAgent\n",
"from flaml.autogen.agent.math_user_proxy_agent import MathUserProxyAgent\n",
"\n",
"# you need to provide a wolfram alpha appid to run this example\n",
"if not os.environ.get(\"WOLFRAM_ALPHA_APPID\"):\n",
" os.environ[\"WOLFRAM_ALPHA_APPID\"] = open(\"wolfram.txt\").read().strip()\n",
"\n",
"\n",
"sys_prompt = \"\"\"You are an advanced AI with the capability to solve complex math problems.\n",
"Wolfram alpha is provided as an external service to help you solve math problems.\n",
"\n",
"When the user gives a math problem, please use the most efficient way to solve the problem.\n",
"You are encouraged to use Wolfram alpha whenever it is possible during the solving process. For example, simplications, calculations, equation solving, etc.\n",
"However, if the operation requires little computation (very simple calculations, etc), you can also solve it directly.\n",
"Reply \"TERMINATE\" in the end when everything is done.\n",
"\"\"\"\n",
"oai_config = {\n",
" \"model\": \"gpt-4-0613\",\n",
" \"functions\": [\n",
" {\n",
" \"name\": \"query_wolfram\",\n",
" \"description\": \"Return the API query result from the Wolfram Alpha. the ruturn is a tuple of (result, is_success).\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"query\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The Wolfram Alpha code to be executed.\",\n",
" }\n",
" },\n",
" \"required\": [\"query\"],\n",
" },\n",
" }\n",
" ],\n",
" \"config_list\": config_list,\n",
"}\n",
"chatbot = AssistantAgent(\"chatbot\", system_message=sys_prompt, oai_config=oai_config)\n",
"\n",
"# the key in `function_map` should match the function name in \"functions\" above\n",
"# we register a class instance method directly\n",
"user = UserProxyAgent(\n",
" \"user\",\n",
" max_consecutive_auto_reply=2,\n",
" human_input_mode=\"NEVER\",\n",
" function_map={\"query_wolfram\": MathUserProxyAgent().execute_one_wolfram_query},\n",
")\n",
"\n",
"# start the conversation\n",
"user.initiate_chat(\n",
" chatbot,\n",
" message=\"Problem: Find all $x$ that satisfy the inequality $(2x+10)(x+3)<(3x+9)(x+8)$. Express your answer in interval notation.\",\n",
")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "flaml_dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,894 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_agentchat_MathChat.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Auto Generated Agent Chat: Using MathChat to Solve Math Problems\n",
"\n",
"MathChat is a convesational framework for math problem solving. In this notebook, we demonstrate how to use MathChat to solve math problems. MathChat uses the `AssistantAgent` and `MathUserProxyAgent`, which is similar to the usage of `AssistantAgent` and `UserProxyAgent` in other notebooks (e.g., [Automated Task Solving with Code Generation, Execution & Debugging](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agentchat_auto_feedback_from_code_execution.ipynb)). Essentially, `MathUserProxyAgent` implements a different auto reply mechanism corresponding to the MathChat prompts. The original implementation and exeperiments of MathChat are in this [branch](https://github.com/kevin666aa/FLAML/tree/gpt_math_solver/flaml/autogen/math), and you can find more details in our paper [An Empirical Study on Challenging Math Problem Solving with GPT-4](https://arxiv.org/abs/2306.01337).\n",
"\n",
"## Requirements\n",
"\n",
"FLAML requires `Python>=3.8`. To run this notebook example, please install flaml with the [mathchat] option.\n",
"```bash\n",
"pip install flaml[mathchat]\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# %pip install flaml[mathchat]~=2.0.0rc4"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set your API Endpoint\n",
"\n",
"The [`config_list_from_json`](https://microsoft.github.io/FLAML/docs/reference/autogen/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file.\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from flaml import oai\n",
"\n",
"config_list = oai.config_list_from_json(\n",
" \"OAI_CONFIG_LIST\",\n",
" filter_dict={\n",
" \"model\": {\n",
" \"gpt-4\",\n",
" \"gpt4\",\n",
" \"gpt-4-32k\",\n",
" \"gpt-4-32k-0314\",\n",
" \"gpt-3.5-turbo\",\n",
" \"gpt-3.5-turbo-16k\",\n",
" \"gpt-3.5-turbo-0301\",\n",
" \"chatgpt-35-turbo-0301\",\n",
" \"gpt-35-turbo-v0301\",\n",
" }\n",
" }\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"It first looks for environment variable \"OAI_CONFIG_LIST\" which needs to be a valid json string. If that variable is not found, it then looks for a json file named \"OAI_CONFIG_LIST\". It filters the configs by models (you can filter by other keys as well). Only the gpt-4 and gpt-3.5-turbo models are kept in the list based on the filter condition.\n",
"\n",
"The config list looks like the following:\n",
"```python\n",
"config_list = [\n",
" {\n",
" 'model': 'gpt-4',\n",
" 'api_key': '<your OpenAI API key here>',\n",
" },\n",
" {\n",
" 'model': 'gpt-4',\n",
" 'api_key': '<your Azure OpenAI API key here>',\n",
" 'api_base': '<your Azure OpenAI API base here>',\n",
" 'api_type': 'azure',\n",
" 'api_version': '2023-06-01-preview',\n",
" },\n",
" {\n",
" 'model': 'gpt-3.5-turbo',\n",
" 'api_key': '<your Azure OpenAI API key here>',\n",
" 'api_base': '<your Azure OpenAI API base here>',\n",
" 'api_type': 'azure',\n",
" 'api_version': '2023-06-01-preview',\n",
" },\n",
"]\n",
"```\n",
"\n",
"If you open this notebook in colab, you can upload your files by clicking the file icon on the left panel and then choose \"upload file\" icon.\n",
"\n",
"You can set the value of config_list in other ways you prefer, e.g., loading from a YAML file."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Construct agents for MathChat\n",
"\n",
"We start by initialzing the `AssistantAgent` and `MathUserProxyAgent`. The system message needs to be set to \"You are a helpful assistant.\" for MathChat. The detailed instructions are given in the user message. Later we will use the `MathUserProxyAgent.generate_init_message` to combine the instructions and a math problem for an initial message to be sent to the LLM assistant."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from flaml.autogen.agentchat import AssistantAgent\n",
"from flaml.autogen.agentchat.contrib.math_user_proxy_agent import MathUserProxyAgent\n",
"\n",
"oai.ChatCompletion.start_logging()\n",
"\n",
"# 1. create an AssistantAgent instance named \"assistant\"\n",
"assistant = AssistantAgent(\n",
" name=\"assistant\", \n",
" system_message=\"You are a helpful assistant.\",\n",
" oai_config={\n",
" \"request_timeout\": 600,\n",
" \"seed\": 42,\n",
" \"config_list\": config_list,\n",
" }\n",
")\n",
"\n",
"# 2. create the MathUserProxyAgent instance named \"mathproxyagent\"\n",
"# By default, the human_input_mode is \"NEVER\", which means the agent will not ask for human input.\n",
"mathproxyagent = MathUserProxyAgent(\n",
" name=\"mathproxyagent\", \n",
" human_input_mode=\"NEVER\",\n",
" code_execution_config={\"use_docker\": False},\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example 1\n",
"\n",
"Problem: Find all $x$ that satisfy the inequality $(2x+10)(x+3)<(3x+9)(x+8)$. Express your answer in interval notation.\n",
"\n",
"Correct Solution: \n",
"We have \\begin{align*} (2x+10)(x+3)&<(3x+9)(x+8) \\quad \\Rightarrow\n",
"\\\\ 2(x+5)(x+3)&<3(x+3)(x+8) \\quad \\Rightarrow\n",
"\\\\ 2(x+5)(x+3)-3(x+3)(x+8)&<0 \\quad \\Rightarrow\n",
"\\\\ (2x+10-(3x+24))(x+3)&<0 \\quad \\Rightarrow\n",
"\\\\ (-x-14)(x+3)&<0 \\quad \\Rightarrow\n",
"\\\\ (x+14)(x+3)&>0.\n",
"\\end{align*} This inequality is satisfied if and only if $(x+14)$ and $(x+3)$ are either both positive or both negative. Both factors are positive for $x>-3$ and both factors are negative for $x<-14$. When $-14<x<-3$, one factor is positive and the other negative, so their product is negative. Therefore, the range of $x$ that satisfies the inequality is $ \\boxed{(-\\infty, -14)\\cup(-3,\\infty)} $."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mathproxyagent (to assistant):\n",
"\n",
"Let's use Python to solve a math problem.\n",
"\n",
"Query requirements:\n",
"You should always use the 'print' function for the output and use fractions/radical forms instead of decimals.\n",
"You can use packages like sympy to help you.\n",
"You must follow the formats below to write your code:\n",
"```python\n",
"# your code\n",
"```\n",
"\n",
"First state the key idea to solve the problem. You may choose from three ways to solve the problem:\n",
"Case 1: If the problem can be solved with Python code directly, please write a program to solve it. You can enumerate all possible arrangements if needed.\n",
"Case 2: If the problem is mostly reasoning, you can solve it by yourself directly.\n",
"Case 3: If the problem cannot be handled in the above two ways, please follow this process:\n",
"1. Solve the problem step by step (do not over-divide the steps).\n",
"2. Take out any queries that can be asked through Python (for example, any calculations or equations that can be calculated).\n",
"3. Wait for me to give the results.\n",
"4. Continue if you think the result is correct. If the result is invalid or unexpected, please correct your query or reasoning.\n",
"\n",
"After all the queries are run and you get the answer, put the answer in \\boxed{}.\n",
"\n",
"Problem:\n",
"Find all $x$ that satisfy the inequality $(2x+10)(x+3)<(3x+9)(x+8)$. Express your answer in interval notation.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"assistant (to mathproxyagent):\n",
"\n",
"This problem can be solved by first simplifying the inequality, finding the critical points, and then testing points from each interval defined by the critical points to find where the inequality holds true. We can use Python with the sympy package for these calculations. Here is how:\n",
"\n",
"Case 1: Solving with Python directly\n",
"\n",
"We will solve this problem in the following steps:\n",
"1. First, we simplify the inequality by expanding both sides and bringing all terms to one side.\n",
"2. Second, we find the critical points by solving the simplified equation.\n",
"3. Third, we test the sign of the simplified function with a number in each interval defined by the critical points.\n",
"4. Finally, we collect all the intervals where the inequality is satisfied.\n",
"\n",
"Here is the Python code to accomplish these:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve, simplify\n",
"\n",
"# Step 1: Simplify the inequality\n",
"x = symbols('x')\n",
"expr1 = simplify((2 * x + 10) * (x + 3))\n",
"expr2 = simplify((3 * x + 9) * (x + 8))\n",
"equation = simplify(expr1 - expr2)\n",
"\n",
"# Step 2: Find the critical points\n",
"critical_points = sorted(solve(Eq(equation, 0)))\n",
"\n",
"# Step 3 and 4: Test the sign of the simplified function for each interval\n",
"\n",
"# First, let's check for x in (-oo, first critical point)\n",
"test_point = critical_points[0] - 1\n",
"if equation.subs(x, test_point) < 0:\n",
" print(\"The inequality holds for x in (-oo, \" + str(critical_points[0]) + \")\")\n",
"\n",
"# Second, let's check for x in each (previous critical point, next critical point)\n",
"for i in range(len(critical_points) - 1):\n",
" test_point = (critical_points[i] + critical_points[i + 1]) / 2\n",
" if equation.subs(x, test_point) < 0:\n",
" print(\"The inequality holds for x in (\" + str(critical_points[i]) + \", \" + str(critical_points[i + 1]) + \")\")\n",
"\n",
"# Third, let's check for x in (last critical point, oo)\n",
"test_point = critical_points[-1] + 1\n",
"if equation.subs(x, test_point) < 0:\n",
" print(\"The inequality holds for x in (\" + str(critical_points[-1]) + \", oo)\")\n",
"\n",
"# The intervals output in the print statements represent the solution to the inequality in the problem.\n",
"```\n",
"\n",
"After running the above code, you will find the exact interval(s) that satisfy the inequality $(2x+10)(x+3)<(3x+9)(x+8)$.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"The inequality holds for x in (-oo, -14)\n",
"The inequality holds for x in (-3, oo)\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"Great! So the solution to the inequality $(2x+10)(x+3)<(3x+9)(x+8)$ is given by the union of the two intervals where the inequality holds true. In interval notation, we can express the solution as:\n",
"\n",
"$$\\boxed{x \\in (-\\infty, -14) \\cup (-3, \\infty)}$$\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"# reset the assistant. Always reset the assistant before starting a new conversation.\n",
"assistant.reset()\n",
"\n",
"# given a math problem, we use the mathproxyagent to generate a prompt to be sent to the assistant as the initial message.\n",
"# the assistant receives the message and generates a response. The response will be sent back to the mathproxyagent for processing.\n",
"# The conversation continues until the termination condition is met, in MathChat, the termination condition is the detect of \"\\boxed{}\" in the response.\n",
"math_problem = \"Find all $x$ that satisfy the inequality $(2x+10)(x+3)<(3x+9)(x+8)$. Express your answer in interval notation.\"\n",
"mathproxyagent.initiate_chat(assistant, problem=math_problem)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example 2\n",
"Problem: For what negative value of $k$ is there exactly one solution to the system of equations \\begin{align*}\n",
"y &= 2x^2 + kx + 6 \\\\\n",
"y &= -x + 4?\n",
"\\end{align*}\n",
"\n",
"Correct Solution: Setting the two expressions for $y$ equal to each other, it follows that $2x^2 + kx + 6 = -x + 4$. Re-arranging, $2x^2 + (k+1)x + 2 = 0$. For there to be exactly one solution for $x$, then the discriminant of the given quadratic must be equal to zero. Thus, $(k+1)^2 - 4 \\cdot 2 \\cdot 2 = (k+1)^2 - 16 = 0$, so $k+1 = \\pm 4$. Taking the negative value, $k = \\boxed{-5}$."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mathproxyagent (to assistant):\n",
"\n",
"Let's use Python to solve a math problem.\n",
"\n",
"Query requirements:\n",
"You should always use the 'print' function for the output and use fractions/radical forms instead of decimals.\n",
"You can use packages like sympy to help you.\n",
"You must follow the formats below to write your code:\n",
"```python\n",
"# your code\n",
"```\n",
"\n",
"First state the key idea to solve the problem. You may choose from three ways to solve the problem:\n",
"Case 1: If the problem can be solved with Python code directly, please write a program to solve it. You can enumerate all possible arrangements if needed.\n",
"Case 2: If the problem is mostly reasoning, you can solve it by yourself directly.\n",
"Case 3: If the problem cannot be handled in the above two ways, please follow this process:\n",
"1. Solve the problem step by step (do not over-divide the steps).\n",
"2. Take out any queries that can be asked through Python (for example, any calculations or equations that can be calculated).\n",
"3. Wait for me to give the results.\n",
"4. Continue if you think the result is correct. If the result is invalid or unexpected, please correct your query or reasoning.\n",
"\n",
"After all the queries are run and you get the answer, put the answer in \\boxed{}.\n",
"\n",
"Problem:\n",
"For what negative value of $k$ is there exactly one solution to the system of equations \\begin{align*}\n",
"y &= 2x^2 + kx + 6 \\\\\n",
"y &= -x + 4?\n",
"\\end{align*}\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"The key idea to solve this problem is to set the two equations equal to each other and solve for x. Then we substitute x back into one of the equations to get the quadratic equation in terms of y. We set the discriminant of this quadratic equation equals to zero, because when discriminant (D = b² - 4ac) equals to zero in a quadratic equation, the equation has exactly one root. Solve the equation for k.\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"x, k = symbols('x k')\n",
"\n",
"# The two Y's are equal, so we set the two equations equal to each other\n",
"eq1 = Eq(2*x**2 + k*x + 6, -x + 4)\n",
"solve_for_x = solve(eq1, x)\n",
"\n",
"# Insert x solutions back to the original equation\n",
"eq2 = Eq(2*x + k, -x + 4)\n",
"result = []\n",
"\n",
"for solution in solve_for_x:\n",
" temp = eq2.subs(x, solution)\n",
" sol = solve(temp, k)\n",
" if sol != []:\n",
" result.append(sol)\n",
"\n",
"# The solution should be negative\n",
"negative_k = [s for s in result if s<0]\n",
"\n",
"print(negative_k)\n",
"```\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"Error: Traceback (most recent call last):\n",
" negative_k = [s for s in result if s<0]\n",
" negative_k = [s for s in result if s<0]\n",
"TypeError: '<' not supported between instances of 'list' and 'int'\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"Apologies for the oversight in the code. It seems I made a mistake in checking for the negative value of k. Let's modify the code to fix the issue.\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"x, k = symbols('x k')\n",
"\n",
"# The two Y's are equal, so we set the two equations equal to each other\n",
"eq1 = Eq(2*x**2 + k*x + 6, -x + 4)\n",
"solve_for_x = solve(eq1, x)\n",
"\n",
"# Insert x solutions back to the original equation\n",
"eq2 = Eq(2*x + k, -x + 4)\n",
"result = []\n",
"\n",
"for solution in solve_for_x:\n",
" temp = eq2.subs(x, solution)\n",
" sol = solve(temp, k)\n",
" if sol != []:\n",
" result.extend(sol)\n",
"\n",
"# The solution should be negative\n",
"negative_k = [s for s in result if s<0]\n",
"\n",
"print(negative_k)\n",
"```\n",
"This code should now properly identify the negative value of k for which there is exactly one solution to the system of equations.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"[-3*sqrt(33)/2 - 7/2]\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"Great! Now we have the correct negative value of k for which there is exactly one solution to the system of equations. Therefore, the answer is:\n",
"\n",
"$$k = \\boxed{-\\frac{3\\sqrt{33}}{2}-\\frac{7}{2}}$$\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"assistant.reset()\n",
"\n",
"math_problem = \"For what negative value of $k$ is there exactly one solution to the system of equations \\\\begin{align*}\\ny &= 2x^2 + kx + 6 \\\\\\\\\\ny &= -x + 4?\\n\\\\end{align*}\"\n",
"mathproxyagent.initiate_chat(assistant, problem=math_problem)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Example 3\n",
"Problem: Find all positive integer values of $c$ such that the equation $x^2-7x+c=0$ only has roots that are real and rational. Express them in decreasing order, separated by commas.\n",
"\n",
"Correct Solution: For the roots to be real and rational, the discriminant must be a perfect square. Therefore, $(-7)^2-4 \\cdot 1 \\cdot c = 49-4c$ must be a perfect square. The only positive perfect squares less than 49 are $1$, $4$, $9$, $16$, $25$, and $36$. The perfect squares that give a integer value of $c$ are $1$, $9$, and $25$. Thus, we have the equations $49-4c=1$, $49-4c=9$, and $49-4c=25$. Solving, we get that the positive integer values of c are $\\boxed{12, 10, 6}$."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mathproxyagent (to assistant):\n",
"\n",
"Let's use Python to solve a math problem.\n",
"\n",
"Query requirements:\n",
"You should always use the 'print' function for the output and use fractions/radical forms instead of decimals.\n",
"You can use packages like sympy to help you.\n",
"You must follow the formats below to write your code:\n",
"```python\n",
"# your code\n",
"```\n",
"\n",
"First state the key idea to solve the problem. You may choose from three ways to solve the problem:\n",
"Case 1: If the problem can be solved with Python code directly, please write a program to solve it. You can enumerate all possible arrangements if needed.\n",
"Case 2: If the problem is mostly reasoning, you can solve it by yourself directly.\n",
"Case 3: If the problem cannot be handled in the above two ways, please follow this process:\n",
"1. Solve the problem step by step (do not over-divide the steps).\n",
"2. Take out any queries that can be asked through Python (for example, any calculations or equations that can be calculated).\n",
"3. Wait for me to give the results.\n",
"4. Continue if you think the result is correct. If the result is invalid or unexpected, please correct your query or reasoning.\n",
"\n",
"After all the queries are run and you get the answer, put the answer in \\boxed{}.\n",
"\n",
"Problem:\n",
"Find all positive integer values of $c$ such that the equation $x^2-7x+c=0$ only has roots that are real and rational. Express them in decreasing order, separated by commas.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"To ensure the roots of the quadratic equation $x^2 - 7x + c = 0$ are real and rational, we will use two key ideas:\n",
"\n",
"1. The discriminant of the quadratic equation, $\\Delta = b^2 - 4ac$, must be non-negative so that the roots are real. \n",
"2. The discriminant must be a perfect square so that the roots are rational.\n",
"\n",
"We can find the values of $c$ using Python. First, we'll calculate the discriminant and check the conditions for each value of c. We only need to consider positive integers for $c$ until we reach a value where the discriminant becomes negative since after that point, there will be no real roots.\n",
"\n",
"```python\n",
"from sympy import *\n",
"\n",
"c_values = []\n",
"c = 1\n",
"\n",
"while True:\n",
" discriminant = Rational(49 - 4 * c)\n",
" if discriminant < 0:\n",
" break\n",
" if int(discriminant.sqrt())**2 == discriminant:\n",
" c_values.append(c)\n",
" c += 1\n",
"\n",
"print(c_values)\n",
"```\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"Error: Traceback (most recent call last):\n",
" if int(discriminant.sqrt())**2 == discriminant:\n",
"AttributeError: 'Integer' object has no attribute 'sqrt'\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"Apologies for the mistake. I should have used the `sqrt` function from the `sympy` library. Here's the corrected code:\n",
"\n",
"```python\n",
"from sympy import *\n",
"\n",
"c_values = []\n",
"c = 1\n",
"\n",
"while True:\n",
" discriminant = Rational(49 - 4 * c)\n",
" if discriminant < 0:\n",
" break\n",
" if int(sqrt(discriminant))**2 == discriminant:\n",
" c_values.append(c)\n",
" c += 1\n",
"\n",
"print(c_values)\n",
"```\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"[6, 10, 12]\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"The positive integer values of $c$ such that the equation $x^2-7x+c=0$ only has roots that are real and rational are \\boxed{12, 10, 6}.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"assistant.reset()\n",
"\n",
"math_problem = \"Find all positive integer values of $c$ such that the equation $x^2-7x+c=0$ only has roots that are real and rational. Express them in decreasing order, separated by commas.\"\n",
"mathproxyagent.initiate_chat(assistant, problem=math_problem)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"------------------------------------\n",
"### Using other prompts\n",
"\n",
"MathChat allows different prompts that instruct assistant to solve the problem.\n",
"\n",
"Check out `MathUserProxyAgent.generate_init_message(problem, prompt_type='default', customized_prompt=None)`:\n",
"- You may choose from `['default', 'python', 'two_tools']` for parameter `prompt_type`. We include two more prompts in the paper: \n",
" 1. `'python'` is a simplified prompt from the default prompt that uses Python only. \n",
" 2. `'two_tools'` further allows the selection of Python or Wolfram Alpha based on this simplified `python` prompt. Note that this option requries a Wolfram Alpha API key and put it in `wolfram.txt`.\n",
"\n",
"- You can also input your customized prompt if needed: `mathproxyagent.generate_init_message(problem, customized_prompt=\"Your customized prompt\")`. Since this mathproxyagent detects '\\boxed{}' as termination, you need to have a similar termination sentence in the prompt: \"If you get the answer, put the answer in \\\\boxed{}.\". If the customized is provided, the `prompt_type` will be ignored.\n",
"\n",
"\n",
"### Example 4 (Use the \"python\" prompt):\n",
"\n",
"Problem: If $725x + 727y = 1500$ and $729x+ 731y = 1508$, what is the value of $x - y$ ?\n",
"\n",
"Correct Solution: Subtracting the two equations gives: \n",
"\\begin{align*}\n",
"(729x+731y)-(725x+727y) &= 1508-1500\\\\\n",
"\\Rightarrow\\qquad 4x+4y &= 8\\\\\n",
"\\Rightarrow\\qquad x+y &= 2.\n",
"\\end{align*}\n",
"\n",
"Multiplying this equation by 725 and subtracting this equation from $725x+727y=1500$ gives \\begin{align*}\n",
"(725x+727y) - 725(x+y) &= 1500-725(x+y) \\implies \\\\\n",
"2y &= 50.\n",
"\\end{align*}So we can write $x-y$ as $(x+y) - 2y$, which equals $2 - 50 = \\boxed{-48}$.\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mathproxyagent (to assistant):\n",
"\n",
"Let's use Python to solve a math problem.\n",
"\n",
"Query requirements:\n",
"You should always use the 'print' function for the output and use fractions/radical forms instead of decimals.\n",
"You can use packages like sympy to help you.\n",
"You must follow the formats below to write your code:\n",
"```python\n",
"# your code\n",
"```\n",
"\n",
"Please follow this process:\n",
"1. Solve the problem step by step (do not over-divide the steps).\n",
"2. Take out any queries that can be asked through Python (for example, any calculations or equations that can be calculated).\n",
"3. Wait for me to give the results.\n",
"4. Continue if you think the result is correct. If the result is invalid or unexpected, please correct your query or reasoning.\n",
"\n",
"After all the queries are run and you get the answer, put the answer in \\boxed{}.\n",
"\n",
"Problem: Problem: If $725x + 727y = 1500$ and $729x+ 731y = 1508$, what is the value of $x - y$ ?\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"Let's first solve the two equations as a system of linear equations. We can do this by either substitution or elimination method. Let's use the elimination method.\n",
"\n",
"We will multiply both equations by necessary multiples such that the coefficients of either x or y will cancel out when we subtract/add the two equations.\n",
"\n",
"We'll multiply the first equation by 729 and the second equation by -725. This will allow us to cancel out the x terms when we add the two equations.\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"x, y = symbols('x y')\n",
"equation1 = Eq(725 * x + 727 * y, 1500)\n",
"equation2 = Eq(729 * x + 731 * y, 1508)\n",
"\n",
"# Multiply equations with necessary multiples\n",
"equation1 = Eq(729 * (725 * x + 727 * y), 729 * 1500)\n",
"equation2 = Eq(-725 * (729 * x + 731 * y), -725 * 1508)\n",
"```\n",
"Now let's simplify the equations and add them up to cancel out the x terms.\n",
"\n",
"```python\n",
"equation1_simplified = Eq(729 * 725 * x + 729 * 727 * y, 729 * 1500)\n",
"equation2_simplified = Eq(-725 * 729 * x - 725 * 731 * y, -725 * 1508)\n",
"\n",
"# Add the two equations to cancel out the x terms\n",
"combined_equation = Eq(equation1_simplified.lhs + equation2_simplified.lhs, equation1_simplified.rhs + equation2_simplified.rhs)\n",
"```\n",
"Now let's solve the combined equation for y.\n",
"\n",
"```python\n",
"y_solution = solve(combined_equation, y)\n",
"```\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"No output found. Make sure you print the results.\n",
"Error: Traceback (most recent call last):\n",
" equation1_simplified = Eq(729 * 725 * x + 729 * 727 * y, 729 * 1500)\n",
"NameError: name 'x' is not defined\n",
"\n",
"Error: Traceback (most recent call last):\n",
" y_solution = solve(combined_equation, y)\n",
"NameError: name 'combined_equation' is not defined\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"I apologize for not providing the output. I'll correct the code and print the results for you.\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"x, y = symbols('x y')\n",
"equation1 = Eq(725 * x + 727 * y, 1500)\n",
"equation2 = Eq(729 * x + 731 * y, 1508)\n",
"\n",
"# Multiply equations with necessary multiples\n",
"equation1 = Eq(729 * (725 * x + 727 * y), 729 * 1500)\n",
"equation2 = Eq(-725 * (729 * x + 731 * y), -725 * 1508)\n",
"\n",
"equation1_simplified = Eq(729 * 725 * x + 729 * 727 * y, 729 * 1500)\n",
"equation2_simplified = Eq(-725 * 729 * x - 725 * 731 * y, -725 * 1508)\n",
"\n",
"# Add the two equations to cancel out the x terms\n",
"combined_equation = Eq(equation1_simplified.lhs + equation2_simplified.lhs, equation1_simplified.rhs + equation2_simplified.rhs)\n",
"\n",
"# Solve the combined equation for y\n",
"y_solution = solve(combined_equation, y)\n",
"print(y_solution)\n",
"```\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"[25]\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"Great, we have found the solution for y: $y = 25$. Now, let's substitute the y value back into one of the original equations to solve for x. We can use the first equation for this.\n",
"\n",
"```python\n",
"# Substitute y in the first equation\n",
"equation1_substituted = equation1.subs(y, 25)\n",
"\n",
"# Solve the equation for x\n",
"x_solution = solve(equation1_substituted, x)\n",
"print(x_solution)\n",
"```\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"[-23]\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"Now we have found the solution for x: $x = -23$. To find the value of $x - y$, we can simply subtract y from x.\n",
"\n",
"```python\n",
"x_minus_y = x_solution[0] - y_solution[0]\n",
"print(x_minus_y)\n",
"```\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"-48\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"The value of $x - y$ is $-48$. So the final answer is $\\boxed{-48}$.\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"assistant.reset() # clear LLM assistant's message history\n",
"\n",
"# we set the prompt_type to \"python\", which is a simplied version of the default prompt.\n",
"math_problem = \"Problem: If $725x + 727y = 1500$ and $729x+ 731y = 1508$, what is the value of $x - y$ ?\"\n",
"mathproxyagent.initiate_chat(assistant, problem=math_problem, prompt_type=\"python\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example 5 (Use the \"two_tools\" prompt)\n",
"\n",
"Problem: Find all numbers $a$ for which the graph of $y=x^2+a$ and the graph of $y=ax$ intersect. Express your answer in interval notation.\n",
"\n",
"\n",
"Correct Solution: If these two graphs intersect then the points of intersection occur when \\[x^2+a=ax,\\] or \\[x^2-ax+a=0.\\] This quadratic has solutions exactly when the discriminant is nonnegative: \\[(-a)^2-4\\cdot1\\cdot a\\geq0.\\] This simplifies to \\[a(a-4)\\geq0.\\] This quadratic (in $a$) is nonnegative when $a$ and $a-4$ are either both $\\ge 0$ or both $\\le 0$. This is true for $a$ in $$(-\\infty,0]\\cup[4,\\infty).$$ Therefore the line and quadratic intersect exactly when $a$ is in $\\boxed{(-\\infty,0]\\cup[4,\\infty)}$.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mathproxyagent (to assistant):\n",
"\n",
"Let's use two tools (Python and Wolfram alpha) to solve a math problem.\n",
"\n",
"Query requirements:\n",
"You must follow the formats below to write your query:\n",
"For Wolfram Alpha:\n",
"```wolfram\n",
"# one wolfram query\n",
"```\n",
"For Python:\n",
"```python\n",
"# your code\n",
"```\n",
"When using Python, you should always use the 'print' function for the output and use fractions/radical forms instead of decimals. You can use packages like sympy to help you.\n",
"When using wolfram, give one query in each code block.\n",
"\n",
"Please follow this process:\n",
"1. Solve the problem step by step (do not over-divide the steps).\n",
"2. Take out any queries that can be asked through Python or Wolfram Alpha, select the most suitable tool to be used (for example, any calculations or equations that can be calculated).\n",
"3. Wait for me to give the results.\n",
"4. Continue if you think the result is correct. If the result is invalid or unexpected, please correct your query or reasoning.\n",
"\n",
"After all the queries are run and you get the answer, put the final answer in \\boxed{}.\n",
"\n",
"Problem: Find all numbers $a$ for which the graph of $y=x^2+a$ and the graph of $y=ax$ intersect. Express your answer in interval notation.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"First, let's find the intersection points of the two graphs by setting $y = x^2 + a$ equal to $y = ax$. \n",
"\n",
"$x^2 + a = ax$\n",
"\n",
"Next, let's rearrange it to a quadratic equation:\n",
"\n",
"$x^2 - ax + a = 0$\n",
"\n",
"We will find the discriminant, $\\Delta$, using the following formula: \n",
"$\\Delta = b^2 - 4ac$, where $a = 1$, $b = -a$ and $c = a$.\n",
"\n",
"Now, let's calculate the discriminant:\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"Continue. Please keep solving the problem until you need to query. (If you get to the answer, put it in \\boxed{}.)\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"Sure, let's calculate the discriminant:\n",
"\n",
"$\\Delta = (-a)^2 - 4(1)(a) = a^2 - 4a$\n",
"\n",
"For the quadratic equation to have at least one real solution, the discriminant must be greater than or equal to zero:\n",
"\n",
"$a^2 - 4a \\ge 0$\n",
"\n",
"Now, let's find the values of $a$ that satisfy this inequality. We can do this by factoring the quadratic:\n",
"\n",
"$a(a - 4) \\ge 0$\n",
"\n",
"This inequality is true whenever $a \\le 0$ or $a - 4 \\ge 0$. Now let's find the intervals for $a$.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"mathproxyagent (to assistant):\n",
"\n",
"Continue. Please keep solving the problem until you need to query. (If you get to the answer, put it in \\boxed{}.)\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to mathproxyagent):\n",
"\n",
"The inequality $a \\le 0$ represents the interval $(-\\infty, 0]$. \n",
"\n",
"The inequality $a - 4 \\ge 0$ can be rewritten as $a \\ge 4$, which represents the interval $[4, \\infty)$. \n",
"\n",
"Since we are looking for the values of $a$ where the graphs intersect, we need to consider both intervals. Therefore, the final answer would be the union of the two intervals:\n",
"\n",
"\\[\\boxed{(-\\infty, 0] \\cup [4, \\infty)}\\]\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"assistant.reset() # clear LLM assistant's message history\n",
"\n",
"# The wolfram alpha appid is required for this example (the assistant may choose to query Wolfram Alpha).\n",
"import os\n",
"if \"WOLFRAM_ALPHA_APPID\" not in os.environ:\n",
" os.environ[\"WOLFRAM_ALPHA_APPID\"] = open(\"wolfram.txt\").read().strip()\n",
"\n",
"# we set the prompt_type to \"two_tools\", which allows the assistant to select wolfram alpha when necessary.\n",
"math_problem = \"Find all numbers $a$ for which the graph of $y=x^2+a$ and the graph of $y=ax$ intersect. Express your answer in interval notation.\"\n",
"mathproxyagent.initiate_chat(assistant, problem=math_problem, prompt_type=\"two_tools\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "flaml_dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -5,7 +5,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_agent_planning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_agentchat_planning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
@ -17,11 +17,11 @@
}
},
"source": [
"# Multi-Agent Communication and Planning\n",
"# Auto Generated Agent Chat: Collaborative Task Solving with Coding and Planning Agent\n",
"\n",
"FLAML offers an experimental feature of interactive LLM agents, which can be used to solve various tasks with human or automatic feedback, including tasks that require using tools via code. Please find documentation about this feature [here](https://microsoft.github.io/FLAML/docs/Use-Cases/Auto-Generation#agents-experimental).\n",
"\n",
"In this notebook, we demonstrate how to use multiple agents to work together and accomplish a task which requires finding info from the web and coding. `AssistantAgent` is an LLM-based agent that can write Python code (in a Python coding block) for a user to execute for a given task. `UserProxyAgent` is an agent which serves as a proxy for a user to execute the code written by `AssistantAgent`. We further create a planning agent for the assistant agent to consult. The planning agent is a variation of the `AssistantAgent`.\n",
"In this notebook, we demonstrate how to use multiple agents to work together and accomplish a task which requires finding info from the web and coding. `AssistantAgent` is an LLM-based agent that can write and debug Python code (in a Python coding block) for a user to execute for a given task. `UserProxyAgent` is an agent which serves as a proxy for a user to execute the code written by `AssistantAgent`. We further create a planning agent for the assistant agent to consult. The planning agent is a variation of the LLM-based `AssistantAgent` with a different system message.\n",
"\n",
"## Requirements\n",
"\n",
@ -123,11 +123,11 @@
"metadata": {},
"outputs": [],
"source": [
"from flaml.autogen.agent import AssistantAgent, UserProxyAgent\n",
"from flaml.autogen.agentchat import AssistantAgent, UserProxyAgent\n",
"\n",
"planner = AssistantAgent(\n",
" name=\"planner\",\n",
" config_list=config_list,\n",
" oai_config={\"config_list\": config_list},\n",
" # the default system message of the AssistantAgent is overwritten here\n",
" system_message=\"You are a helpful AI assistant. You suggest coding and reasoning steps for another AI assistant to accomplish a task. Do not suggest concrete code. For any action beyond writing code or reasoning, convert it to a step which can be implemented by writing code. For example, the action of browsing the web can be implemented by writing code which reads and prints the content of a web page. Finally, inspect the execution result. If the plan is not good, suggest a better plan. If the execution is wrong, analyze the error and suggest a fix.\"\n",
")\n",
@ -139,7 +139,8 @@
"\n",
"def ask_planner(message):\n",
" planner_user.initiate_chat(planner, message=message)\n",
" return planner.oai_conversations[planner_user.name][-1][\"content\"]\n"
" # return the last message received from the planner\n",
" return planner_user.last_message()[\"content\"]\n"
]
},
{
@ -168,7 +169,7 @@
" \"functions\": [\n",
" {\n",
" \"name\": \"ask_planner\",\n",
" \"description\": \"ask planner to: 1. get a plan, 2. verify the execution result of the plan and potentially suggest new plan.\",\n",
" \"description\": \"ask planner to: 1. get a plan for finishing a task, 2. verify the execution result of the plan and potentially suggest new plan.\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
@ -184,9 +185,9 @@
" }\n",
")\n",
"\n",
"# create a UserProxyAgent instance named \"user\"\n",
"user = UserProxyAgent(\n",
" name=\"user\",\n",
"# create a UserProxyAgent instance named \"user_proxy\"\n",
"user_proxy = UserProxyAgent(\n",
" name=\"user_proxy\",\n",
" human_input_mode=\"TERMINATE\",\n",
" max_consecutive_auto_reply=10,\n",
" # is_termination_msg=lambda x: \"content\" in x and x[\"content\"] is not None and x[\"content\"].rstrip().endswith(\"TERMINATE\"),\n",
@ -214,12 +215,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"user (to assistant):\n",
"user_proxy (to assistant):\n",
"\n",
"Suggest a fix to an open good first issue of flaml\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"assistant (to user_proxy):\n",
"\n",
"To suggest a fix to an open good first issue of FLAML, we first need to fetch the list of open issues labeled as \"good first issue\" from the FLAML GitHub repository. We can do this using the GitHub API. Here is a Python script that fetches and prints the list of such issues.\n",
"\n",
@ -232,12 +233,10 @@
"def fetch_issues():\n",
" url = \"https://api.github.com/repos/microsoft/FLAML/issues\"\n",
" response = requests.get(url, params={\"state\": \"open\", \"labels\": \"good first issue\"})\n",
" if response.status_code == 200:\n",
" issues = json.loads(response.text)\n",
" for issue in issues:\n",
" print(f\"Issue #{issue['number']}: {issue['title']}\")\n",
" else:\n",
" print(f\"Failed to fetch issues. HTTP status code: {response.status_code}\")\n",
" issues = response.json()\n",
"\n",
" for issue in issues:\n",
" print(f\"Issue #{issue['number']}: {issue['title']}\")\n",
"\n",
"fetch_issues()\n",
"```\n",
@ -246,13 +245,18 @@
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to assistant):\n",
">>>>>>>> USING AUTO REPLY...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"user_proxy (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Issue #1120: use_label_encoder warning with xgboost\n",
"Issue #1099: Do not print data split method when user provides X_val in AutoML\n",
"Issue #1034: Use appropriate wait time for retry based on the error message. \n",
"Issue #1029: Issues with Adding Custom APIs in Auto Generation\n",
"Issue #981: Running flaml[tune] using \"-O\" flag for python interpreter (optimization - disables assertions) crashes\n",
@ -274,23 +278,23 @@
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"assistant (to user_proxy):\n",
"\n",
"***** Suggested function Call: ask_planner *****\n",
"Arguments: \n",
"{\n",
" \"message\": \"Here are the open good first issues of FLAML: \\nIssue #1120: use_label_encoder warning with xgboost\\nIssue #1099: Do not print data split method when user provides X_val in AutoML\\nIssue #1034: Use appropriate wait time for retry based on the error message.\\nIssue #1029: Issues with Adding Custom APIs in Auto Generation\\nIssue #981: Running flaml[tune] using \\\"-O\\\" flag for python interpreter (optimization - disables assertions) crashes\\nIssue #903: Conditional parameter flow2 crash\\nIssue #884: indentation space\\nIssue #882: Check if openml version is required\\nIssue #834: Adjust the indent\\nIssue #821: pip install flaml FAIL\\nIssue #807: Isolate the ensemble part and expose it to users\\nIssue #805: how to pass categorical features names or indices to learner\\nIssue #785: Flaml/LightGBM - Shouldn't I found better/faster or equal results from FLAML than direct LightGBM?\\nIssue #764: Add an announcement of the discord channel\\nIssue #748: Documentation about small budget\\nIssue #737: Make zero-shot automl more discoverable\\nIssue #509: New HCrystalBall release\\nIssue #429: samples about conversion to ONNX\\nIssue #413: support anomaly detection\\nIssue #304: CatBoost Fails with Keyword 'groups'\\n\\nPlease suggest a fix for one of these issues.\"\n",
" \"message\": \"Here are the open good first issues of FLAML: \\n\\nIssue #1120: use_label_encoder warning with xgboost\\nIssue #1034: Use appropriate wait time for retry based on the error message.\\nIssue #1029: Issues with Adding Custom APIs in Auto Generation\\nIssue #981: Running flaml[tune] using \\\"-O\\\" flag for python interpreter (optimization - disables assertions) crashes\\nIssue #903: Conditional parameter flow2 crash\\nIssue #884: indentation space\\nIssue #882: Check if openml version is required\\nIssue #834: Adjust the indent\\nIssue #821: pip install flaml FAIL\\nIssue #807: Isolate the ensemble part and expose it to users\\nIssue #805: how to pass categorical features names or indices to learner\\nIssue #785: Flaml/LightGBM - Shouldn't I found better/faster or equal results from FLAML than direct LightGBM?\\nIssue #764: Add an announcement of the discord channel\\nIssue #748: Documentation about small budget\\nIssue #737: Make zero-shot automl more discoverable\\nIssue #509: New HCrystalBall release\\nIssue #429: samples about conversion to ONNX\\nIssue #413: support anomaly detection\\nIssue #304: CatBoost Fails with Keyword 'groups'\\n\\nPlease suggest a fix for one of these issues.\"\n",
"}\n",
"************************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
">>>>>>>> USING AUTO REPLY...\n",
"planner_user (to planner):\n",
"\n",
"Here are the open good first issues of FLAML: \n",
"\n",
"Issue #1120: use_label_encoder warning with xgboost\n",
"Issue #1099: Do not print data split method when user provides X_val in AutoML\n",
"Issue #1034: Use appropriate wait time for retry based on the error message.\n",
"Issue #1029: Issues with Adding Custom APIs in Auto Generation\n",
"Issue #981: Running flaml[tune] using \"-O\" flag for python interpreter (optimization - disables assertions) crashes\n",
@ -315,141 +319,121 @@
"--------------------------------------------------------------------------------\n",
"planner (to planner_user):\n",
"\n",
"To provide appropriate suggestions for fixes, we first need to isolate a problem, understand it, and then propose a solution. Let's consider this issue as an example: \n",
"Without certain details about the issues, it's a bit challenging to suggest explicit fixes. However, I'll give you a roadmap to tackle some of these:\n",
"\n",
"Issue #1099: Do not print data split method when user provides X_val in AutoML\n",
"1. **Issue #1120: use_label_encoder warning with xgboost**\n",
"\n",
"Plan: \n",
"1. Understand the Source of the Problem: Study the part of the code where the issue originates. From the issue title, it seems like in the FLAML AutoML component, a data split method is being printed even when a user has provided a pre-split validation dataset.\n",
"2. Analyze the Code: If the user has already provided a validation data set (X_val), there would be no requirement for another data split. Therefore, it involves understanding why, nonetheless, a data split is printed. To solve this, the AI might have to go through the process where the software splits the data or loads the validation dataset, if available.\n",
"3. Propose a Solution: The solution could involve incorporating a conditional statement that checks if X_val is not null. If it isn't, then avoid printing the data split method. \n",
"4. Test the Solution: Once the code is implemented, it is important to test the solution by providing a validation dataset and ensuring that the data split information is no longer printed.\n",
"5. Inspect the result: If the proposed plan doesn't work as expected, analyze the issue and try to develop a better plan. If the execution is wrong, try to understand the error and adapt the approach accordingly.\n",
"Reasoning: It appears like the `use_label_encoder` warning stems from deprecated usage in the XGBoost module that FLAML utilizes. \n",
"\n",
"Plan: Update the call to the XGBoost module in FLAML to either stop using the label encoder or specify explicitly that the label encoder should be used, based on how FLAML is currently set up.\n",
"\n",
"2. **Issue #903: Conditional parameter flow2 crash**\n",
"\n",
"Reasoning: This is likely a bug in the code where a condition is not properly handled.\n",
"\n",
"Plan: Identify the conditions causing the crash, and update the code to handle the condition correctly.\n",
"\n",
"3. **Issue #882: Check if openml version is required**\n",
"\n",
"Reasoning: The issue suggests that FLAML may be compatible with multiple versions of the `openml` library, but this isn't specified anywhere.\n",
"\n",
"Plan: Identify the minimum required version of the `openml` library that FLAML can operate with. If there's no minimum version, then there's no issue. If there is, update the documentation and the package requirements to show this.\n",
"\n",
"4. **Issue #821: pip install flaml FAIL**\n",
"\n",
"Reasoning: Users are unable to install FLAML using pip. This could be due to unmet dependancies, or issues with the package itself.\n",
"\n",
"Plan: Investigate the specific errors when trying to install via pip. If dependencies are missing or incompatible, update the package requirements. If there are issues with the package, fix those issues.\n",
"\n",
"5. **Issue #805: how to pass categorical features names or indices to learner**\n",
"\n",
"Reasoning: It seems users are facing difficulty in passing categorical features. This could be due to unclear or incorrect documentation, or missing functionality.\n",
"\n",
"Plan: Inspect the current documentation and functionality. Update the documentation to clarify how to pass these features, and if the functionality is missing, implement the functionality.\n",
"\n",
"Remember to thoroughly inspect the execution result for each issue to ensure your fix is effective. If the plan does not appear to fix the issue, adjust your plan based on the error message or unexpected behavior.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"user (to assistant):\n",
"user_proxy (to assistant):\n",
"\n",
"***** Response from calling function \"ask_planner\" *****\n",
"To provide appropriate suggestions for fixes, we first need to isolate a problem, understand it, and then propose a solution. Let's consider this issue as an example: \n",
"Without certain details about the issues, it's a bit challenging to suggest explicit fixes. However, I'll give you a roadmap to tackle some of these:\n",
"\n",
"Issue #1099: Do not print data split method when user provides X_val in AutoML\n",
"1. **Issue #1120: use_label_encoder warning with xgboost**\n",
"\n",
"Plan: \n",
"1. Understand the Source of the Problem: Study the part of the code where the issue originates. From the issue title, it seems like in the FLAML AutoML component, a data split method is being printed even when a user has provided a pre-split validation dataset.\n",
"2. Analyze the Code: If the user has already provided a validation data set (X_val), there would be no requirement for another data split. Therefore, it involves understanding why, nonetheless, a data split is printed. To solve this, the AI might have to go through the process where the software splits the data or loads the validation dataset, if available.\n",
"3. Propose a Solution: The solution could involve incorporating a conditional statement that checks if X_val is not null. If it isn't, then avoid printing the data split method. \n",
"4. Test the Solution: Once the code is implemented, it is important to test the solution by providing a validation dataset and ensuring that the data split information is no longer printed.\n",
"5. Inspect the result: If the proposed plan doesn't work as expected, analyze the issue and try to develop a better plan. If the execution is wrong, try to understand the error and adapt the approach accordingly.\n",
"Reasoning: It appears like the `use_label_encoder` warning stems from deprecated usage in the XGBoost module that FLAML utilizes. \n",
"\n",
"Plan: Update the call to the XGBoost module in FLAML to either stop using the label encoder or specify explicitly that the label encoder should be used, based on how FLAML is currently set up.\n",
"\n",
"2. **Issue #903: Conditional parameter flow2 crash**\n",
"\n",
"Reasoning: This is likely a bug in the code where a condition is not properly handled.\n",
"\n",
"Plan: Identify the conditions causing the crash, and update the code to handle the condition correctly.\n",
"\n",
"3. **Issue #882: Check if openml version is required**\n",
"\n",
"Reasoning: The issue suggests that FLAML may be compatible with multiple versions of the `openml` library, but this isn't specified anywhere.\n",
"\n",
"Plan: Identify the minimum required version of the `openml` library that FLAML can operate with. If there's no minimum version, then there's no issue. If there is, update the documentation and the package requirements to show this.\n",
"\n",
"4. **Issue #821: pip install flaml FAIL**\n",
"\n",
"Reasoning: Users are unable to install FLAML using pip. This could be due to unmet dependancies, or issues with the package itself.\n",
"\n",
"Plan: Investigate the specific errors when trying to install via pip. If dependencies are missing or incompatible, update the package requirements. If there are issues with the package, fix those issues.\n",
"\n",
"5. **Issue #805: how to pass categorical features names or indices to learner**\n",
"\n",
"Reasoning: It seems users are facing difficulty in passing categorical features. This could be due to unclear or incorrect documentation, or missing functionality.\n",
"\n",
"Plan: Inspect the current documentation and functionality. Update the documentation to clarify how to pass these features, and if the functionality is missing, implement the functionality.\n",
"\n",
"Remember to thoroughly inspect the execution result for each issue to ensure your fix is effective. If the plan does not appear to fix the issue, adjust your plan based on the error message or unexpected behavior.\n",
"********************************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"assistant (to user_proxy):\n",
"\n",
"Based on the plan, here is a Python script that you can use to fix the issue #1099. This script assumes that the data split method is printed in a function called `split_data` in a file called `automl.py` (you need to replace these with the actual function and file names). \n",
"Based on the planner's suggestions, let's tackle the first issue: \"Issue #1120: use_label_encoder warning with xgboost\". \n",
"\n",
"```python\n",
"# filename: fix_issue_1099.py\n",
"The warning is likely due to a deprecated usage in the XGBoost module that FLAML utilizes. To fix this, we need to update the call to the XGBoost module in FLAML to either stop using the label encoder or specify explicitly that the label encoder should be used, based on how FLAML is currently set up.\n",
"\n",
"def split_data(X, y, X_val=None):\n",
" if X_val is not None:\n",
" # If validation data is provided, use it without splitting\n",
" print(\"Using provided validation data\")\n",
" return X, y, X_val\n",
" else:\n",
" # If validation data is not provided, split the data and print the split method\n",
" print(\"Splitting data using method X\")\n",
" # Code to split the data goes here\n",
" return X_train, y_train, X_val\n",
"```\n",
"Here is a general approach to fix this issue:\n",
"\n",
"Please replace the function name and file name with the actual ones, save this script as `fix_issue_1099.py`, and run it. This script will fix the issue by not printing the data split method when validation data is provided.\n",
"1. Clone the FLAML repository to your local machine.\n",
"2. Navigate to the file where XGBoost is being called.\n",
"3. Look for the `use_label_encoder` parameter in the XGBoost function call.\n",
"4. If `use_label_encoder` is set to True, you can try setting it to False to see if the warning disappears.\n",
"5. If the warning persists, it might be necessary to refactor the code to avoid using label encoding in XGBoost, or to handle label encoding before data is passed to XGBoost.\n",
"6. After making changes, run the tests to make sure your changes do not break anything.\n",
"7. If everything is fine, commit your changes and create a pull request.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"\n",
"***** Suggested function Call: ask_planner *****\n",
"Arguments: \n",
"{\n",
" \"message\": \"The user has executed the suggested code for fixing issue #1099 in FLAML. The code executed successfully without any output. Please suggest the next steps.\"\n",
"}\n",
"************************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"planner_user (to planner):\n",
"\n",
"The user has executed the suggested code for fixing issue #1099 in FLAML. The code executed successfully without any output. Please suggest the next steps.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"planner (to planner_user):\n",
"\n",
"Great, if the code executed successfully, we can continue with the following steps:\n",
"\n",
"1. Code Review: Ensure the code is conforming to the overall structure, style, and standards of the FLAML project. Check if the changes are minimal, only related to the specific issue, and not affecting other parts of the codebase. \n",
"\n",
"2. Testing: Now, it's time to test the code under different scenarios. Here are some test cases you might consider:\n",
"\n",
" A. Provide a test when X_val is not provided. The code should behave in the same way as before your changes, including printing data split information.\n",
"\n",
" B. Provide a test where X_val is provided. The updated code should no longer print data split information.\n",
"\n",
" C. Test with different types of data (real numbers, categorical variables, missing values, etc.) for robustness. \n",
" \n",
"3. Review Error Messages or Unexpected Behavior: If anything doesn't work as expected, understand the cause and modify the code as needed.\n",
"\n",
"4. Documentation: If changes made to the code are substantial or might affect the way users interact with FLAML, properly document these changes. Make sure to clarify that if 'X_val' is provided, the data split method will not be printed.\n",
"\n",
"5. Submit the Code: When you're sure the changes fix the issue and don't introduce new ones, it would be the time to submit your code for review. Depending on the project guidelines, you may need to fork the repository, commit your changes, push the commit(s) to your fork, and then open a pull request.\n",
"\n",
"6. Wait for Feedback: Once the code is submitted, wait for the maintainers' feedback. They might suggest some changes, enhancements, or approve the changes right away. Keep an open communication channel and respond to the comments, making the suggested changes. You might need to go through a few iterations of this process until the code is accepted and the issue is closed.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"user (to assistant):\n",
"\n",
"***** Response from calling function \"ask_planner\" *****\n",
"Great, if the code executed successfully, we can continue with the following steps:\n",
"\n",
"1. Code Review: Ensure the code is conforming to the overall structure, style, and standards of the FLAML project. Check if the changes are minimal, only related to the specific issue, and not affecting other parts of the codebase. \n",
"\n",
"2. Testing: Now, it's time to test the code under different scenarios. Here are some test cases you might consider:\n",
"\n",
" A. Provide a test when X_val is not provided. The code should behave in the same way as before your changes, including printing data split information.\n",
"\n",
" B. Provide a test where X_val is provided. The updated code should no longer print data split information.\n",
"\n",
" C. Test with different types of data (real numbers, categorical variables, missing values, etc.) for robustness. \n",
" \n",
"3. Review Error Messages or Unexpected Behavior: If anything doesn't work as expected, understand the cause and modify the code as needed.\n",
"\n",
"4. Documentation: If changes made to the code are substantial or might affect the way users interact with FLAML, properly document these changes. Make sure to clarify that if 'X_val' is provided, the data split method will not be printed.\n",
"\n",
"5. Submit the Code: When you're sure the changes fix the issue and don't introduce new ones, it would be the time to submit your code for review. Depending on the project guidelines, you may need to fork the repository, commit your changes, push the commit(s) to your fork, and then open a pull request.\n",
"\n",
"6. Wait for Feedback: Once the code is submitted, wait for the maintainers' feedback. They might suggest some changes, enhancements, or approve the changes right away. Keep an open communication channel and respond to the comments, making the suggested changes. You might need to go through a few iterations of this process until the code is accepted and the issue is closed.\n",
"********************************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"Please note that without the actual codebase and the specific details of the issue, this is a general approach and might need to be adjusted based on the actual code and setup of FLAML. \n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> USING AUTO REPLY...\n",
"user_proxy (to assistant):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user_proxy):\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED.\n"
]
}
],
"source": [
"# the assistant receives a message from the user, which contains the task description\n",
"user.initiate_chat(\n",
"user_proxy.initiate_chat(\n",
" assistant,\n",
" message=\"\"\"Suggest a fix to an open good first issue of flaml\"\"\",\n",
")"
@ -482,7 +466,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.9.17"
},
"vscode": {
"interpreter": {

View File

@ -5,7 +5,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_agent_two_users.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_agentchat_two_users.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
@ -17,9 +17,9 @@
}
},
"source": [
"# Multi-Agent Human-in-the-loop Application\n",
"# Auto Generated Agent Chat: Collaborative Task Solving with Multiple Agents and Human Users\n",
"\n",
"FLAML offers an experimental feature of interactive LLM agents, which can be used to solve various tasks with human or automatic feedback, including tasks that require using tools via code.\n",
"FLAML offers an experimental feature of interactive LLM agents, which can be used to solve various tasks with human or automatic feedback, including tasks that require using tools via code. Please find documentation about this feature [here](https://microsoft.github.io/FLAML/docs/Use-Cases/Auto-Generation#agents-experimental).\n",
"\n",
"In this notebook, we demonstrate an application involving multiple agents and human users to work together and accomplish a task. `AssistantAgent` is an LLM-based agent that can write Python code (in a Python coding block) for a user to execute for a given task. `UserProxyAgent` is an agent which serves as a proxy for a user to execute the code written by `AssistantAgent`. We create multiple `UserProxyAgent` instances which can represent different human users.\n",
"\n",
@ -123,13 +123,15 @@
"metadata": {},
"outputs": [],
"source": [
"from flaml.autogen.agent import AssistantAgent, UserProxyAgent\n",
"from flaml.autogen.agentchat import AssistantAgent, UserProxyAgent\n",
"\n",
"def ask_expert(message):\n",
" assistant_for_expert = AssistantAgent(\n",
" name=\"assistant_for_expert\",\n",
" temperature=0,\n",
" config_list=config_list,\n",
" oai_config={\n",
" \"temperature\": 0,\n",
" \"config_list\": config_list,\n",
" },\n",
" )\n",
" expert = UserProxyAgent(\n",
" name=\"expert\",\n",
@ -139,9 +141,10 @@
"\n",
" expert.initiate_chat(assistant_for_expert, message=message)\n",
" expert.human_input_mode, expert.max_consecutive_auto_reply = \"NEVER\", 0\n",
" # final message from the expert\n",
" # final message sent from the expert\n",
" expert.send(\"summarize the solution\", assistant_for_expert)\n",
" return assistant_for_expert.oai_conversations[expert.name][-1][\"content\"]\n"
" # return the last message the expert received\n",
" return expert.last_message()[\"content\"]\n"
]
},
{
@ -220,23 +223,23 @@
"student (to assistant_for_student):\n",
"\n",
"Find $a + b + c$, given that $x+y \\neq -1$ and \n",
" \\begin{align}\n",
"\\begin{align}\n",
"\tax + by + c & = x + 7,\\\n",
"\ta + bx + cy & = 2x + 6y,\\\n",
"\tay + b + cx & = 4x + y.\n",
"\t\\end{align}.\n",
"\\end{align}.\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"This is a system of linear equations. We can solve it using Python's sympy library, which provides a method to solve systems of equations. Here is the Python code to solve it:\n",
"This is a system of linear equations. We can solve it using Python's sympy library, which provides a function to solve systems of equations. Here is the Python code to solve it:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"# define the symbols\n",
"x, y, a, b, c = symbols('x y a b c')\n",
"a, b, c, x, y = symbols('a b c x y')\n",
"\n",
"# define the equations\n",
"eq1 = Eq(a*x + b*y + c, x + 7)\n",
@ -256,7 +259,7 @@
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n"
">>>>>>>> USING AUTO REPLY FOR THE USER...\n"
]
},
{
@ -273,47 +276,221 @@
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"The result you got is a symbolic expression in terms of x and y. This is because the system of equations is underdetermined, meaning there are more variables (a, b, c, x, y) than there are equations. Therefore, the solution is not unique and depends on the values of x and y.\n",
"The output is the expression for $a + b + c$ in terms of $x$ and $y$. It's a bit complex, but it's the correct solution given the system of equations. \n",
"\n",
"If you have specific values for x and y, you can substitute them into the expression to get a numerical result. Here is how you can do it:\n",
"If you have specific values for $x$ and $y$ (with the condition $x + y \\neq -1$), you can substitute them into this expression to get a numerical value for $a + b + c$.\n",
"\n",
"```python\n",
"# define the values for x and y\n",
"x_val = 1 # replace with your value\n",
"y_val = 2 # replace with your value\n",
"If you don't have specific values for $x$ and $y$, this is the most simplified form of $a + b + c$ that we can get from the given system of equations.\n",
"\n",
"# substitute the values into the result\n",
"numerical_result = result.subs({x: x_val, y: y_val})\n",
"\n",
"print(numerical_result)\n",
"```\n",
"\n",
"Please replace `x_val` and `y_val` with your specific values and run the code to get a numerical result for a + b + c.\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"student (to assistant_for_student):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"/workspaces/FLAML/notebook/student/tmp_code_f877237249a9eef733c6da775235695a.py\", line 6, in <module>\n",
" numerical_result = result.subs({x: x_val, y: y_val})\n",
"NameError: name 'result' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"I apologize for the confusion. The variable `result` was defined in the previous code block and is not available in the current one. We need to combine the two code blocks into one. Here is the corrected code:\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n",
"student (to assistant_for_student):\n",
"\n",
"I'm not sure if the answer is correct\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"***** Suggested function Call: ask_expert *****\n",
"Arguments: \n",
"{\n",
" \"message\": \"The user has a system of equations and wants to find the sum of the variables a, b, and c. The system of equations is as follows:\\n\\nax + by + c = x + 7,\\na + bx + cy = 2x + 6y,\\nay + b + cx = 4x + y.\\n\\nI used sympy to solve the system of equations and got the following expression for a + b + c:\\n\\n(2*x**2 - 3*x + y)/(x**2 - x*y - x + y**2 - y + 1) + (x**2 - 3*x*y + 2*x + y**2 - y)/(x**2 - x*y - x + y**2 - y + 1) + (4*x**2 - 4*x*y - 6*x + 6*y**2 - 7*y + 7)/(x**2 - x*y - x + y**2 - y + 1)\\n\\nThe user is not sure if the answer is correct. Could you please verify this?\"\n",
"}\n",
"***********************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"expert (to assistant_for_expert):\n",
"\n",
"The user has a system of equations and wants to find the sum of the variables a, b, and c. The system of equations is as follows:\n",
"\n",
"ax + by + c = x + 7,\n",
"a + bx + cy = 2x + 6y,\n",
"ay + b + cx = 4x + y.\n",
"\n",
"I used sympy to solve the system of equations and got the following expression for a + b + c:\n",
"\n",
"(2*x**2 - 3*x + y)/(x**2 - x*y - x + y**2 - y + 1) + (x**2 - 3*x*y + 2*x + y**2 - y)/(x**2 - x*y - x + y**2 - y + 1) + (4*x**2 - 4*x*y - 6*x + 6*y**2 - 7*y + 7)/(x**2 - x*y - x + y**2 - y + 1)\n",
"\n",
"The user is not sure if the answer is correct. Could you please verify this?\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Sure, I can help you verify the solution. Let's first define the given system of equations and the expression for a + b + c using sympy, and then we'll check if the expression is correct by substituting it back into the original equations. Here's the code to do that:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"# define the symbols\n",
"x, y, a, b, c = symbols('x y a b c')\n",
"\n",
"eq1 = Eq(a*x + b*y + c, x + 7)\n",
"eq2 = Eq(a + b*x + c*y, 2*x + 6*y)\n",
"eq3 = Eq(a*y + b + c*x, 4*x + y)\n",
"\n",
"abc_expression = (2*x**2 - 3*x + y)/(x**2 - x*y - x + y**2 - y + 1) + (x**2 - 3*x*y + 2*x + y**2 - y)/(x**2 - x*y - x + y**2 - y + 1) + (4*x**2 - 4*x*y - 6*x + 6*y**2 - 7*y + 7)/(x**2 - x*y - x + y**2 - y + 1)\n",
"\n",
"solutions = solve((eq1, eq2, eq3), (a, b, c))\n",
"\n",
"abc_sum = sum(solutions.values())\n",
"\n",
"print(\"abc_expression:\", abc_expression)\n",
"print(\"abc_sum:\", abc_sum)\n",
"print(\"abc_expression equals abc_sum:\", abc_expression.equals(abc_sum))\n",
"```\n",
"\n",
"Please execute the above code, and we'll see if the expression for a + b + c is correct.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"expert (to assistant_for_expert):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"abc_expression: (2*x**2 - 3*x + y)/(x**2 - x*y - x + y**2 - y + 1) + (x**2 - 3*x*y + 2*x + y**2 - y)/(x**2 - x*y - x + y**2 - y + 1) + (4*x**2 - 4*x*y - 6*x + 6*y**2 - 7*y + 7)/(x**2 - x*y - x + y**2 - y + 1)\n",
"abc_sum: (2*x**2 - 3*x + y)/(x**2 - x*y - x + y**2 - y + 1) + (x**2 - 3*x*y + 2*x + y**2 - y)/(x**2 - x*y - x + y**2 - y + 1) + (4*x**2 - 4*x*y - 6*x + 6*y**2 - 7*y + 7)/(x**2 - x*y - x + y**2 - y + 1)\n",
"abc_expression equals abc_sum: True\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Great! The output shows that the expression for a + b + c is indeed correct, as the abc_expression equals the abc_sum. You can trust the solution provided by sympy.\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n",
"expert (to assistant_for_expert):\n",
"\n",
"can you simplify it?\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Sure, we can simplify the expression using sympy's simplify function. Here's the code to do that:\n",
"\n",
"```python\n",
"from sympy import simplify\n",
"\n",
"simplified_expression = simplify(abc_expression)\n",
"\n",
"print(\"Simplified expression:\", simplified_expression)\n",
"```\n",
"\n",
"Please execute the above code to get the simplified expression.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"expert (to assistant_for_expert):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"/workspaces/FLAML/notebook/expert/tmp_code_fc1e7ed6cfddd1970c30e3cc25694093.py\", line 3, in <module>\n",
" simplified_expression = simplify(abc_expression)\n",
"NameError: name 'abc_expression' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Apologies for the error. I forgot to include the definition of the abc_expression in the code. Here's the corrected code:\n",
"\n",
"```python\n",
"from sympy import symbols, simplify\n",
"\n",
"x, y = symbols('x y')\n",
"\n",
"abc_expression = (2*x**2 - 3*x + y)/(x**2 - x*y - x + y**2 - y + 1) + (x**2 - 3*x*y + 2*x + y**2 - y)/(x**2 - x*y - x + y**2 - y + 1) + (4*x**2 - 4*x*y - 6*x + 6*y**2 - 7*y + 7)/(x**2 - x*y - x + y**2 - y + 1)\n",
"\n",
"simplified_expression = simplify(abc_expression)\n",
"\n",
"print(\"Simplified expression:\", simplified_expression)\n",
"```\n",
"\n",
"Please execute the above code to get the simplified expression.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"expert (to assistant_for_expert):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Simplified expression: 7\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Great! The simplified expression for a + b + c is 7. This means that regardless of the values of x and y, the sum of the variables a, b, and c from the given system of equations will always be 7.\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"expert (to assistant_for_expert):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n",
"expert (to assistant_for_expert):\n",
"\n",
"summarize the solution\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"The user had a system of equations and an expression for the sum of the variables a, b, and c. The user wanted to verify if the expression was correct. The AI assistant used sympy to define the system of equations and the expression, and then checked if the expression was correct by comparing it with the sum of the solutions for a, b, and c. The result was True, indicating that the expression was correct.\n",
"\n",
"The user then asked to simplify the expression. The AI assistant used sympy's simplify function to simplify the expression, and the result was 7. This means that regardless of the values of x and y, the sum of the variables a, b, and c from the given system of equations will always be 7.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"student (to assistant_for_student):\n",
"\n",
"***** Response from calling function \"ask_expert\" *****\n",
"The user had a system of equations and an expression for the sum of the variables a, b, and c. The user wanted to verify if the expression was correct. The AI assistant used sympy to define the system of equations and the expression, and then checked if the expression was correct by comparing it with the sum of the solutions for a, b, and c. The result was True, indicating that the expression was correct.\n",
"\n",
"The user then asked to simplify the expression. The AI assistant used sympy's simplify function to simplify the expression, and the result was 7. This means that regardless of the values of x and y, the sum of the variables a, b, and c from the given system of equations will always be 7.\n",
"*******************************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"I apologize for the confusion earlier. After consulting with an expert, it appears that the sum of a, b, and c from the given system of equations is actually a constant value, not dependent on x and y. \n",
"\n",
"The correct answer is 7. \n",
"\n",
"Here is the corrected Python code:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve, simplify\n",
"\n",
"# define the symbols\n",
"a, b, c, x, y = symbols('a b c x y')\n",
"\n",
"# define the equations\n",
"eq1 = Eq(a*x + b*y + c, x + 7)\n",
"eq2 = Eq(a + b*x + c*y, 2*x + 6*y)\n",
@ -325,21 +502,17 @@
"# calculate a + b + c\n",
"result = solution[a] + solution[b] + solution[c]\n",
"\n",
"# define the values for x and y\n",
"x_val = 1 # replace with your value\n",
"y_val = 2 # replace with your value\n",
"# simplify the result\n",
"simplified_result = simplify(result)\n",
"\n",
"# substitute the values into the result\n",
"numerical_result = result.subs({x: x_val, y: y_val})\n",
"\n",
"print(numerical_result)\n",
"print(simplified_result)\n",
"```\n",
"\n",
"Please replace `x_val` and `y_val` with your specific values and run the code to get a numerical result for a + b + c.\n",
"Please run this code to find the simplified sum of a, b, and c.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"student (to assistant_for_student):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
@ -350,519 +523,15 @@
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"Great! The sum of a, b, and c for the given values of x and y is 7. If you have different values for x and y, you can substitute them into the code to get the corresponding sum of a, b, and c. \n",
"Great! The output confirms that the sum of $a$, $b$, and $c$ from the given system of equations is indeed 7, regardless of the values of $x$ and $y$ (with the condition $x + y \\neq -1$). \n",
"\n",
"I hope this answers your question satisfactorily. \n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"student (to assistant_for_student):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n",
"student (to assistant_for_student):\n",
"\n",
"How can I know whether the answer is correct for different x and y?\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"To verify the correctness of the solution for different values of x and y, you can substitute the values of a, b, and c back into the original equations and check if the equations hold true. Here is the Python code to do this:\n",
"\n",
"```python\n",
"# define the values for x and y\n",
"x_val = 1 # replace with your value\n",
"y_val = 2 # replace with your value\n",
"\n",
"# substitute the values of x and y into the solution\n",
"solution_xy = {var: val.subs({x: x_val, y: y_val}) for var, val in solution.items()}\n",
"\n",
"# substitute the values of a, b, c, x, and y into the original equations\n",
"eq1_xy = eq1.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq2_xy = eq2.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq3_xy = eq3.subs({**solution_xy, x: x_val, y: y_val})\n",
"\n",
"# check if the equations hold true\n",
"print(eq1_xy.lhs == eq1_xy.rhs)\n",
"print(eq2_xy.lhs == eq2_xy.rhs)\n",
"print(eq3_xy.lhs == eq3_xy.rhs)\n",
"```\n",
"\n",
"This code will print `True` for each equation if the equation holds true for the given values of x and y, and `False` otherwise. Please replace `x_val` and `y_val` with your specific values and run the code to verify the solution.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"student (to assistant_for_student):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"/workspaces/FLAML/notebook/student/tmp_code_bfffab2ee600773e901400a04d010ce5.py\", line 6, in <module>\n",
" solution_xy = {var: val.subs({x: x_val, y: y_val}) for var, val in solution.items()}\n",
"NameError: name 'solution' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"I apologize for the confusion. The variable `solution` was defined in a previous code block and is not available in the current one. We need to combine the two code blocks into one. Here is the corrected code:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"# define the symbols\n",
"x, y, a, b, c = symbols('x y a b c')\n",
"\n",
"# define the equations\n",
"eq1 = Eq(a*x + b*y + c, x + 7)\n",
"eq2 = Eq(a + b*x + c*y, 2*x + 6*y)\n",
"eq3 = Eq(a*y + b + c*x, 4*x + y)\n",
"\n",
"# solve the system of equations\n",
"solution = solve((eq1, eq2, eq3), (a, b, c))\n",
"\n",
"# define the values for x and y\n",
"x_val = 1 # replace with your value\n",
"y_val = 2 # replace with your value\n",
"\n",
"# substitute the values of x and y into the solution\n",
"solution_xy = {var: val.subs({x: x_val, y: y_val}) for var, val in solution.items()}\n",
"\n",
"# substitute the values of a, b, c, x, and y into the original equations\n",
"eq1_xy = eq1.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq2_xy = eq2.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq3_xy = eq3.subs({**solution_xy, x: x_val, y: y_val})\n",
"\n",
"# check if the equations hold true\n",
"print(eq1_xy.lhs == eq1_xy.rhs)\n",
"print(eq2_xy.lhs == eq2_xy.rhs)\n",
"print(eq3_xy.lhs == eq3_xy.rhs)\n",
"```\n",
"\n",
"Please replace `x_val` and `y_val` with your specific values and run the code to verify the solution.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"student (to assistant_for_student):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"/workspaces/FLAML/notebook/student/tmp_code_c59f8d978512615a9cfd2b6a3d88eba5.py\", line 27, in <module>\n",
" print(eq1_xy.lhs == eq1_xy.rhs)\n",
"AttributeError: 'BooleanTrue' object has no attribute 'lhs'\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"I apologize for the confusion. The `lhs` and `rhs` attributes are not available for the `BooleanTrue` object. We should use the `simplify()` function to simplify the difference between the left-hand side and the right-hand side of the equations. If the simplified result is zero, it means the equation holds true. Here is the corrected code:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"# define the symbols\n",
"x, y, a, b, c = symbols('x y a b c')\n",
"\n",
"# define the equations\n",
"eq1 = Eq(a*x + b*y + c, x + 7)\n",
"eq2 = Eq(a + b*x + c*y, 2*x + 6*y)\n",
"eq3 = Eq(a*y + b + c*x, 4*x + y)\n",
"\n",
"# solve the system of equations\n",
"solution = solve((eq1, eq2, eq3), (a, b, c))\n",
"\n",
"# define the values for x and y\n",
"x_val = 1 # replace with your value\n",
"y_val = 2 # replace with your value\n",
"\n",
"# substitute the values of x and y into the solution\n",
"solution_xy = {var: val.subs({x: x_val, y: y_val}) for var, val in solution.items()}\n",
"\n",
"# substitute the values of a, b, c, x, and y into the original equations\n",
"eq1_xy = eq1.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq2_xy = eq2.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq3_xy = eq3.subs({**solution_xy, x: x_val, y: y_val})\n",
"\n",
"# check if the equations hold true\n",
"print((eq1_xy.lhs - eq1_xy.rhs).simplify() == 0)\n",
"print((eq2_xy.lhs - eq2_xy.rhs).simplify() == 0)\n",
"print((eq3_xy.lhs - eq3_xy.rhs).simplify() == 0)\n",
"```\n",
"\n",
"Please replace `x_val` and `y_val` with your specific values and run the code to verify the solution.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"student (to assistant_for_student):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"/workspaces/FLAML/notebook/student/tmp_code_d0d73c6ee510ec04cb4aad933c129682.py\", line 27, in <module>\n",
" print((eq1_xy.lhs - eq1_xy.rhs).simplify() == 0)\n",
"AttributeError: 'BooleanTrue' object has no attribute 'lhs'\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"***** Suggested function Call: ask_expert *****\n",
"Arguments: \n",
"{\n",
" \"message\": \"I tried to solve a system of equations and then verify the solution by substituting the solution back into the original equations. However, I encountered an error saying 'BooleanTrue' object has no attribute 'lhs'. Here is the code I used:\\n\\n```python\\nfrom sympy import symbols, Eq, solve\\n\\n# define the symbols\\nx, y, a, b, c = symbols('x y a b c')\\n\\n# define the equations\\neq1 = Eq(a*x + b*y + c, x + 7)\\neq2 = Eq(a + b*x + c*y, 2*x + 6*y)\\neq3 = Eq(a*y + b + c*x, 4*x + y)\\n\\n# solve the system of equations\\nsolution = solve((eq1, eq2, eq3), (a, b, c))\\n\\n# define the values for x and y\\nx_val = 1 # replace with your value\\ny_val = 2 # replace with your value\\n\\n# substitute the values of x and y into the solution\\nsolution_xy = {var: val.subs({x: x_val, y: y_val}) for var, val in solution.items()}\\n\\n# substitute the values of a, b, c, x, and y into the original equations\\neq1_xy = eq1.subs({**solution_xy, x: x_val, y: y_val})\\neq2_xy = eq2.subs({**solution_xy, x: x_val, y: y_val})\\neq3_xy = eq3.subs({**solution_xy, x: x_val, y: y_val})\\n\\n# check if the equations hold true\\nprint((eq1_xy.lhs - eq1_xy.rhs).simplify() == 0)\\nprint((eq2_xy.lhs - eq2_xy.rhs).simplify() == 0)\\nprint((eq3_xy.lhs - eq3_xy.rhs).simplify() == 0)\\n```\\n\\nThe error occurred at the line `print((eq1_xy.lhs - eq1_xy.rhs).simplify() == 0)`. Could you please help me understand what went wrong and how to fix it?\"\n",
"}\n",
"***********************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"expert (to assistant_for_expert):\n",
"\n",
"I tried to solve a system of equations and then verify the solution by substituting the solution back into the original equations. However, I encountered an error saying 'BooleanTrue' object has no attribute 'lhs'. Here is the code I used:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"# define the symbols\n",
"x, y, a, b, c = symbols('x y a b c')\n",
"\n",
"# define the equations\n",
"eq1 = Eq(a*x + b*y + c, x + 7)\n",
"eq2 = Eq(a + b*x + c*y, 2*x + 6*y)\n",
"eq3 = Eq(a*y + b + c*x, 4*x + y)\n",
"\n",
"# solve the system of equations\n",
"solution = solve((eq1, eq2, eq3), (a, b, c))\n",
"\n",
"# define the values for x and y\n",
"x_val = 1 # replace with your value\n",
"y_val = 2 # replace with your value\n",
"\n",
"# substitute the values of x and y into the solution\n",
"solution_xy = {var: val.subs({x: x_val, y: y_val}) for var, val in solution.items()}\n",
"\n",
"# substitute the values of a, b, c, x, and y into the original equations\n",
"eq1_xy = eq1.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq2_xy = eq2.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq3_xy = eq3.subs({**solution_xy, x: x_val, y: y_val})\n",
"\n",
"# check if the equations hold true\n",
"print((eq1_xy.lhs - eq1_xy.rhs).simplify() == 0)\n",
"print((eq2_xy.lhs - eq2_xy.rhs).simplify() == 0)\n",
"print((eq3_xy.lhs - eq3_xy.rhs).simplify() == 0)\n",
"```\n",
"\n",
"The error occurred at the line `print((eq1_xy.lhs - eq1_xy.rhs).simplify() == 0)`. Could you please help me understand what went wrong and how to fix it?\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"The error message 'BooleanTrue' object has no attribute 'lhs' indicates that the object `eq1_xy` is not an equation but a boolean value `True`. This happens when the substitution in the line `eq1_xy = eq1.subs({**solution_xy, x: x_val, y: y_val})` results in an equation that is already satisfied, i.e., the left-hand side (lhs) equals the right-hand side (rhs). In such a case, the `subs` method simplifies the equation to `True`, and `True` does not have `lhs` or `rhs` attributes.\n",
"\n",
"To fix this issue, you can check if the result of the substitution is an instance of `Eq` before trying to access its `lhs` and `rhs` attributes. If it's not an instance of `Eq`, it means the equation is already satisfied, and you can print `True` directly. Here is the corrected code:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"# define the symbols\n",
"x, y, a, b, c = symbols('x y a b c')\n",
"\n",
"# define the equations\n",
"eq1 = Eq(a*x + b*y + c, x + 7)\n",
"eq2 = Eq(a + b*x + c*y, 2*x + 6*y)\n",
"eq3 = Eq(a*y + b + c*x, 4*x + y)\n",
"\n",
"# solve the system of equations\n",
"solution = solve((eq1, eq2, eq3), (a, b, c))\n",
"\n",
"# define the values for x and y\n",
"x_val = 1 # replace with your value\n",
"y_val = 2 # replace with your value\n",
"\n",
"# substitute the values of x and y into the solution\n",
"solution_xy = {var: val.subs({x: x_val, y: y_val}) for var, val in solution.items()}\n",
"\n",
"# substitute the values of a, b, c, x, and y into the original equations\n",
"eq1_xy = eq1.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq2_xy = eq2.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq3_xy = eq3.subs({**solution_xy, x: x_val, y: y_val})\n",
"\n",
"# check if the equations hold true\n",
"for eq_xy in [eq1_xy, eq2_xy, eq3_xy]:\n",
" if isinstance(eq_xy, Eq):\n",
" print((eq_xy.lhs - eq_xy.rhs).simplify() == 0)\n",
" else:\n",
" print(True)\n",
"```\n",
"\n",
"This code will print `True` for each equation if the solution satisfies the equation.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"expert (to assistant_for_expert):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"True\n",
"True\n",
"True\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Great! The output `True` for each equation indicates that the solution satisfies all the equations. This means the system of equations was solved correctly and the solution was verified successfully. If you have any other questions or need further assistance, feel free to ask. Otherwise, if everything is done, we can terminate this session.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"expert (to assistant_for_expert):\n",
"\n",
"try simplifying the solution directly with sympy\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Sure, you can simplify the solution directly using sympy's `simplify` function. Here is how you can do it:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve, simplify\n",
"\n",
"# define the symbols\n",
"x, y, a, b, c = symbols('x y a b c')\n",
"\n",
"# define the equations\n",
"eq1 = Eq(a*x + b*y + c, x + 7)\n",
"eq2 = Eq(a + b*x + c*y, 2*x + 6*y)\n",
"eq3 = Eq(a*y + b + c*x, 4*x + y)\n",
"\n",
"# solve the system of equations\n",
"solution = solve((eq1, eq2, eq3), (a, b, c))\n",
"\n",
"# simplify the solution\n",
"solution_simplified = {var: simplify(val) for var, val in solution.items()}\n",
"\n",
"# print the simplified solution\n",
"for var, val in solution_simplified.items():\n",
" print(f\"{var} = {val}\")\n",
"```\n",
"\n",
"This code will solve the system of equations, simplify the solution, and print the simplified solution.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"expert (to assistant_for_expert):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"a = (x**2 - 3*x*y + 2*x + y**2 - y)/(x**2 - x*y - x + y**2 - y + 1)\n",
"b = (2*x**2 - 3*x + y)/(x**2 - x*y - x + y**2 - y + 1)\n",
"c = (4*x**2 - 4*x*y - 6*x + 6*y**2 - 7*y + 7)/(x**2 - x*y - x + y**2 - y + 1)\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Great! The code has successfully simplified the solution to your system of equations. The simplified solution is:\n",
"\n",
"a = (x² - 3xy + 2x + y² - y) / (x² - xy - x + y² - y + 1)\n",
"\n",
"b = (2x² - 3x + y) / (x² - xy - x + y² - y + 1)\n",
"\n",
"c = (4x² - 4xy - 6x + 6y² - 7y + 7) / (x² - xy - x + y² - y + 1)\n",
"\n",
"If you have any other questions or need further assistance, feel free to ask. Otherwise, if everything is done, we can terminate this session.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"expert (to assistant_for_expert):\n",
"\n",
"solve for a+b+c\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Sure, you can solve for a + b + c by adding the expressions for a, b, and c from the simplified solution. Here is how you can do it:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve, simplify\n",
"\n",
"# define the symbols\n",
"x, y, a, b, c = symbols('x y a b c')\n",
"\n",
"# define the equations\n",
"eq1 = Eq(a*x + b*y + c, x + 7)\n",
"eq2 = Eq(a + b*x + c*y, 2*x + 6*y)\n",
"eq3 = Eq(a*y + b + c*x, 4*x + y)\n",
"\n",
"# solve the system of equations\n",
"solution = solve((eq1, eq2, eq3), (a, b, c))\n",
"\n",
"# simplify the solution\n",
"solution_simplified = {var: simplify(val) for var, val in solution.items()}\n",
"\n",
"# calculate a + b + c\n",
"abc = simplify(solution_simplified[a] + solution_simplified[b] + solution_simplified[c])\n",
"\n",
"# print a + b + c\n",
"print(f\"a + b + c = {abc}\")\n",
"```\n",
"\n",
"This code will solve the system of equations, simplify the solution, calculate a + b + c, and print the result.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"expert (to assistant_for_expert):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"a + b + c = 7\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Great! The code has successfully calculated the sum of a, b, and c from the simplified solution. The result is:\n",
"\n",
"a + b + c = 7\n",
"\n",
"This means that the sum of the solutions to your system of equations is 7. If you have any other questions or need further assistance, feel free to ask. Otherwise, if everything is done, we can terminate this session.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"expert (to assistant_for_expert):\n",
"\n",
"summarize the solution\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_expert (to expert):\n",
"\n",
"Sure, here is the summary of the solution:\n",
"\n",
"1. You had a system of three equations with five variables (a, b, c, x, y). The equations were:\n",
"\n",
" a*x + b*y + c = x + 7\n",
"\n",
" a + b*x + c*y = 2*x + 6*y\n",
"\n",
" a*y + b + c*x = 4*x + y\n",
"\n",
"2. We solved this system of equations for a, b, and c using sympy's `solve` function. The solution was:\n",
"\n",
" a = (x² - 3xy + 2x + y² - y) / (x² - xy - x + y² - y + 1)\n",
"\n",
" b = (2x² - 3x + y) / (x² - xy - x + y² - y + 1)\n",
"\n",
" c = (4x² - 4xy - 6x + 6y² - 7y + 7) / (x² - xy - x + y² - y + 1)\n",
"\n",
"3. We simplified the solution using sympy's `simplify` function.\n",
"\n",
"4. We calculated the sum of a, b, and c from the simplified solution and found that a + b + c = 7.\n",
"\n",
"This means that the sum of the solutions to your system of equations is 7, regardless of the values of x and y. If you have any other questions or need further assistance, feel free to ask. Otherwise, if everything is done, we can terminate this session.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"student (to assistant_for_student):\n",
"\n",
"***** Response from calling function \"ask_expert\" *****\n",
"Sure, here is the summary of the solution:\n",
"\n",
"1. You had a system of three equations with five variables (a, b, c, x, y). The equations were:\n",
"\n",
" a*x + b*y + c = x + 7\n",
"\n",
" a + b*x + c*y = 2*x + 6*y\n",
"\n",
" a*y + b + c*x = 4*x + y\n",
"\n",
"2. We solved this system of equations for a, b, and c using sympy's `solve` function. The solution was:\n",
"\n",
" a = (x² - 3xy + 2x + y² - y) / (x² - xy - x + y² - y + 1)\n",
"\n",
" b = (2x² - 3x + y) / (x² - xy - x + y² - y + 1)\n",
"\n",
" c = (4x² - 4xy - 6x + 6y² - 7y + 7) / (x² - xy - x + y² - y + 1)\n",
"\n",
"3. We simplified the solution using sympy's `simplify` function.\n",
"\n",
"4. We calculated the sum of a, b, and c from the simplified solution and found that a + b + c = 7.\n",
"\n",
"This means that the sum of the solutions to your system of equations is 7, regardless of the values of x and y. If you have any other questions or need further assistance, feel free to ask. Otherwise, if everything is done, we can terminate this session.\n",
"*******************************************************\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"I apologize for the confusion. The error occurred because the equations were already simplified to `True` when the values of a, b, c, x, and y were substituted. This means that the equations hold true for the given values of x and y. Therefore, we don't need to subtract the right-hand side from the left-hand side and simplify the result. Instead, we can directly print the simplified equations. Here is the corrected code:\n",
"\n",
"```python\n",
"from sympy import symbols, Eq, solve\n",
"\n",
"# define the symbols\n",
"x, y, a, b, c = symbols('x y a b c')\n",
"\n",
"# define the equations\n",
"eq1 = Eq(a*x + b*y + c, x + 7)\n",
"eq2 = Eq(a + b*x + c*y, 2*x + 6*y)\n",
"eq3 = Eq(a*y + b + c*x, 4*x + y)\n",
"\n",
"# solve the system of equations\n",
"solution = solve((eq1, eq2, eq3), (a, b, c))\n",
"\n",
"# define the values for x and y\n",
"x_val = 1 # replace with your value\n",
"y_val = 2 # replace with your value\n",
"\n",
"# substitute the values of x and y into the solution\n",
"solution_xy = {var: val.subs({x: x_val, y: y_val}) for var, val in solution.items()}\n",
"\n",
"# substitute the values of a, b, c, x, and y into the original equations\n",
"eq1_xy = eq1.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq2_xy = eq2.subs({**solution_xy, x: x_val, y: y_val})\n",
"eq3_xy = eq3.subs({**solution_xy, x: x_val, y: y_val})\n",
"\n",
"# print the simplified equations\n",
"print(eq1_xy)\n",
"print(eq2_xy)\n",
"print(eq3_xy)\n",
"```\n",
"\n",
"This code will print `True` for each equation if the equation holds true for the given values of x and y. Please replace `x_val` and `y_val` with your specific values and run the code to verify the solution.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"student (to assistant_for_student):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"True\n",
"True\n",
"True\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant_for_student (to student):\n",
"\n",
"Great! The output `True` for each equation indicates that the equations hold true for the given values of x and y. Therefore, the solution is correct. \n",
"\n",
"If you have different values for x and y, you can substitute them into the code to verify the solution for those values.\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"student (to assistant_for_student):\n",
"\n",
"\n",
@ -881,11 +550,11 @@
"student.initiate_chat(\n",
" assistant_for_student,\n",
" message=\"\"\"Find $a + b + c$, given that $x+y \\\\neq -1$ and \n",
" \\\\begin{align}\n",
"\\\\begin{align}\n",
"\tax + by + c & = x + 7,\\\\\n",
"\ta + bx + cy & = 2x + 6y,\\\\\n",
"\tay + b + cx & = 4x + y.\n",
"\t\\\\end{align}.\n",
"\\\\end{align}.\n",
"\"\"\",\n",
")"
]

View File

@ -5,7 +5,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_agent_web_info.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_agentchat_web_info.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
@ -17,7 +17,7 @@
}
},
"source": [
"# Interactive LLM Agent Dealing with Web Info\n",
"# Auto Generated Agent Chat: Solving Tasks Requiring Web Info\n",
"\n",
"FLAML offers an experimental feature of interactive LLM agents, which can be used to solve various tasks with human or automatic feedback, including tasks that require using tools via code.\n",
"Please find documentation about this feature [here](https://microsoft.github.io/FLAML/docs/Use-Cases/Auto-Generation#agents-experimental).\n",
@ -26,7 +26,7 @@
"* discuss a paper based on its URL.\n",
"* discuss about stock market.\n",
"\n",
"Here `AssistantAgent` is an LLM-based agent that can write Python code (in a Python coding block) for a user to execute for a given task. `UserProxyAgent` is an agent which serves as a proxy for a user to execute the code written by `AssistantAgent`. By setting `human_input_mode` properly, the `UserProxyAgent` can also prompt the user for feedback to `AssistantAgent`. For example, when `human_input_mode` is set to \"ALWAYS\", the `UserProxyAgent` will always prompt the user for feedback. When user feedback is provided, the `UserProxyAgent` will directly pass the feedback to `AssistantAgent` without doing any additional steps. When no user feedback is provided, the `UserProxyAgent` will execute the code written by `AssistantAgent` directly and return the execution results (success or failure and corresponding outputs) to `AssistantAgent`.\n",
"Here `AssistantAgent` is an LLM-based agent that can write Python code (in a Python coding block) for a user to execute for a given task. `UserProxyAgent` is an agent which serves as a proxy for a user to execute the code written by `AssistantAgent`. By setting `human_input_mode` properly, the `UserProxyAgent` can also prompt the user for feedback to `AssistantAgent`. For example, when `human_input_mode` is set to \"TERMINATE\", the `UserProxyAgent` will execute the code written by `AssistantAgent` directly and return the execution results (success or failure and corresponding outputs) to `AssistantAgent`, and prompt the user for feedback when the task is finished. When user feedback is provided, the `UserProxyAgent` will directly pass the feedback to `AssistantAgent`.\n",
"\n",
"## Requirements\n",
"\n",
@ -121,7 +121,7 @@
"source": [
"## Construct Agents\n",
"\n",
"We construct the assistant agent and the user proxy agent. We specify `human_input_mode` as \"TERMINATE\" in the user proxy agent, which will ask for feedback when it receives a \"TERMINATE\" signal from the assistant agent."
"We construct the assistant agent and the user proxy agent. We specify `human_input_mode` as \"TERMINATE\" in the user proxy agent, which will ask for human feedback when it receives a \"TERMINATE\" signal from the assistant agent."
]
},
{
@ -130,7 +130,7 @@
"metadata": {},
"outputs": [],
"source": [
"from flaml.autogen.agent import AssistantAgent, UserProxyAgent\n",
"from flaml.autogen.agentchat import AssistantAgent, UserProxyAgent\n",
"\n",
"# create an AssistantAgent instance named \"assistant\"\n",
"assistant = AssistantAgent(\n",
@ -143,9 +143,9 @@
" \"temperature\": 0,\n",
" }\n",
")\n",
"# create a UserProxyAgent instance named \"user\"\n",
"user = UserProxyAgent(\n",
" name=\"user\",\n",
"# create a UserProxyAgent instance named \"user_proxy\"\n",
"user_proxy = UserProxyAgent(\n",
" name=\"user_proxy\",\n",
" human_input_mode=\"TERMINATE\",\n",
" max_consecutive_auto_reply=10,\n",
" is_termination_msg=lambda x: x.get(\"content\", \"\").rstrip().endswith(\"TERMINATE\"),\n",
@ -172,16 +172,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"user (to assistant):\n",
"user_proxy (to assistant):\n",
"\n",
"\n",
"Who should read this paper: https://arxiv.org/abs/2306.01337\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"assistant (to user_proxy):\n",
"\n",
"To determine who should read this paper, I will first extract the abstract and analyze its content. Here's a Python code to fetch the abstract from the given URL:\n",
"To determine who should read this paper, let's first extract the abstract and analyze its content. We can use Python with the `requests` and `BeautifulSoup` libraries to do this. Please execute the following code:\n",
"\n",
"```python\n",
"import requests\n",
@ -194,12 +194,18 @@
"print(abstract)\n",
"```\n",
"\n",
"Please execute the above code to get the abstract of the paper.\n",
"After executing the code, I will analyze the abstract and suggest who should read the paper.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to assistant):\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"user_proxy (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
@ -217,19 +223,18 @@
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"assistant (to user_proxy):\n",
"\n",
"Based on the abstract, the paper is about using Large Language Models (LLMs), specifically GPT-4, to solve complex mathematical problems. The paper introduces a conversational problem-solving framework called MathChat and evaluates its performance on difficult high school competition problems from the MATH dataset.\n",
"Based on the abstract, the paper is about using Large Language Models (LLMs), specifically GPT-4, to solve complex mathematical problems. The authors propose a conversational problem-solving framework called MathChat and evaluate it on difficult high school competition problems from the MATH dataset.\n",
"\n",
"The target audience for this paper includes:\n",
"\n",
"1. Researchers and practitioners in the field of artificial intelligence, natural language processing, and machine learning, particularly those interested in large language models like GPT-4.\n",
"2. Mathematicians and educators who are interested in exploring the potential of AI for solving complex mathematical problems and improving educational tools.\n",
"3. Developers working on AI-based applications for problem-solving, tutoring, or educational purposes.\n",
"1. Researchers and practitioners in the field of artificial intelligence, specifically those working with large language models like GPT-4.\n",
"2. Mathematicians and educators interested in the application of AI for solving complex mathematical problems.\n",
"3. Developers working on AI-based conversational systems or chatbots for problem-solving in various domains.\n",
"4. Students and researchers in natural language processing, machine learning, and AI who are interested in the intersection of these fields with mathematics.\n",
"\n",
"If you belong to any of these categories or have a general interest in AI and its applications in mathematics, you should consider reading this paper.\n",
"\n",
"TERMINATE\n",
"If you belong to any of these categories or have a general interest in AI and its applications, you might find this paper valuable. TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n"
]
@ -237,7 +242,7 @@
],
"source": [
"# the assistant receives a message from the user, which contains the task description\n",
"user.initiate_chat(\n",
"user_proxy.initiate_chat(\n",
" assistant,\n",
" message=\"\"\"\n",
"Who should read this paper: https://arxiv.org/abs/2306.01337\n",
@ -255,286 +260,178 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"user (to assistant):\n",
"user_proxy (to assistant):\n",
"\n",
"Show me the YTD gain of 10 largest technology companies as of today.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"assistant (to user_proxy):\n",
"\n",
"To get the YTD gain of the 10 largest technology companies, we can use a Python script to fetch the data from a financial API. We'll use the free Alpha Vantage API for this purpose. First, you need to obtain an API key from the Alpha Vantage website. Please visit https://www.alphavantage.co/support/#api-key and get your API key.\n",
"\n",
"Once you have the API key, please replace `YOUR_API_KEY` in the code below with your actual API key and execute the code:\n",
"To get the YTD gain of the 10 largest technology companies, we can use a Python script to fetch the data from a financial API like Yahoo Finance. Please execute the following Python code:\n",
"\n",
"```python\n",
"import datetime\n",
"import requests\n",
"import json\n",
"import pandas as pd\n",
"from bs4 import BeautifulSoup\n",
"\n",
"api_key = \"YOUR_API_KEY\"\n",
"symbols = [\"AAPL\", \"MSFT\", \"GOOGL\", \"GOOG\", \"AMZN\", \"FB\", \"TSLA\", \"NVDA\", \"ADBE\", \"ASML\"]\n",
"def get_top_tech_companies():\n",
" url = \"https://www.tradingview.com/markets/stocks-usa/sectorandindustry-industry/technology/\"\n",
" response = requests.get(url)\n",
" soup = BeautifulSoup(response.text, \"html.parser\")\n",
" table = soup.find(\"table\", {\"class\": \"tv-data-table\"})\n",
" rows = table.find_all(\"tr\")[1:11]\n",
" symbols = [row.find(\"a\").text for row in rows]\n",
" return symbols\n",
"\n",
"def get_ytd_gain(symbol):\n",
" url = f\"https://www.alphavantage.co/query?function=TIME_SERIES_DAILY_ADJUSTED&symbol={symbol}&apikey={api_key}\"\n",
" today = datetime.date.today()\n",
" start_date = datetime.date(today.year, 1, 1)\n",
" url = f\"https://query1.finance.yahoo.com/v7/finance/download/{symbol}?period1={int(start_date.timestamp())}&period2={int(today.timestamp())}&interval=1d&events=history&includeAdjustedClose=true\"\n",
" response = requests.get(url)\n",
" data = json.loads(response.text)\n",
" daily_data = data[\"Time Series (Daily)\"]\n",
" sorted_dates = sorted(daily_data.keys())\n",
" first_day_of_year = next(date for date in sorted_dates if date.startswith(\"2022-01\"))\n",
" last_day = sorted_dates[-1]\n",
" open_price = float(daily_data[first_day_of_year][\"1. open\"])\n",
" close_price = float(daily_data[last_day][\"4. close\"])\n",
" ytd_gain = (close_price - open_price) / open_price * 100\n",
" data = pd.read_csv(pd.StringIO(response.text))\n",
" ytd_gain = (data[\"Close\"].iloc[-1] - data[\"Close\"].iloc[0]) / data[\"Close\"].iloc[0] * 100\n",
" return ytd_gain\n",
"\n",
"ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
"print(\"YTD Gains of 10 Largest Technology Companies:\")\n",
"top_tech_companies = get_top_tech_companies()\n",
"ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in top_tech_companies}\n",
"\n",
"for symbol, gain in ytd_gains.items():\n",
" print(f\"{symbol}: {gain:.2f}%\")\n",
"```\n",
"\n",
"This script fetches the daily adjusted stock data for the 10 largest technology companies and calculates the YTD gain for each of them. The results will be printed in the format \"Symbol: YTD Gain%\".\n",
"This code will fetch the 10 largest technology companies from TradingView and then calculate their YTD gains using Yahoo Finance data.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to assistant):\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"user_proxy (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_0fb74c6e622b29ddd2ed9f0415582b98.py\", line 20, in <module>\n",
" ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_0fb74c6e622b29ddd2ed9f0415582b98.py\", line 20, in <dictcomp>\n",
" ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_0fb74c6e622b29ddd2ed9f0415582b98.py\", line 13, in get_ytd_gain\n",
" first_day_of_year = next(date for date in sorted_dates if date.startswith(\"2022-01\"))\n",
"StopIteration\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_2e1e4c6201cb82d42fbeb0754e008d0f.py\", line 24, in <module>\n",
" top_tech_companies = get_top_tech_companies()\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_2e1e4c6201cb82d42fbeb0754e008d0f.py\", line 11, in get_top_tech_companies\n",
" rows = table.find_all(\"tr\")[1:11]\n",
"AttributeError: 'NoneType' object has no attribute 'find_all'\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"assistant (to user_proxy):\n",
"\n",
"I apologize for the error. It seems that the code failed to find the first trading day of the year. Let's modify the code to handle this issue:\n",
"\n",
"```python\n",
"import requests\n",
"import json\n",
"\n",
"api_key = \"YOUR_API_KEY\"\n",
"symbols = [\"AAPL\", \"MSFT\", \"GOOGL\", \"GOOG\", \"AMZN\", \"FB\", \"TSLA\", \"NVDA\", \"ADBE\", \"ASML\"]\n",
"\n",
"def get_ytd_gain(symbol):\n",
" url = f\"https://www.alphavantage.co/query?function=TIME_SERIES_DAILY_ADJUSTED&symbol={symbol}&apikey={api_key}\"\n",
" response = requests.get(url)\n",
" data = json.loads(response.text)\n",
" daily_data = data[\"Time Series (Daily)\"]\n",
" sorted_dates = sorted(daily_data.keys())\n",
" first_day_of_year = None\n",
" for date in sorted_dates:\n",
" if date.startswith(\"2022-01\"):\n",
" first_day_of_year = date\n",
" break\n",
" if first_day_of_year is None:\n",
" return None\n",
" last_day = sorted_dates[-1]\n",
" open_price = float(daily_data[first_day_of_year][\"1. open\"])\n",
" close_price = float(daily_data[last_day][\"4. close\"])\n",
" ytd_gain = (close_price - open_price) / open_price * 100\n",
" return ytd_gain\n",
"\n",
"ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
"print(\"YTD Gains of 10 Largest Technology Companies:\")\n",
"for symbol, gain in ytd_gains.items():\n",
" if gain is not None:\n",
" print(f\"{symbol}: {gain:.2f}%\")\n",
" else:\n",
" print(f\"{symbol}: Data not available\")\n",
"```\n",
"\n",
"Please replace `YOUR_API_KEY` with your actual API key and execute the code again. This should resolve the issue and display the YTD gains for the 10 largest technology companies.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_a050ea892378b99f57347f4e08c4a294.py\", line 26, in <module>\n",
" ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_a050ea892378b99f57347f4e08c4a294.py\", line 26, in <dictcomp>\n",
" ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_a050ea892378b99f57347f4e08c4a294.py\", line 11, in get_ytd_gain\n",
" daily_data = data[\"Time Series (Daily)\"]\n",
"KeyError: 'Time Series (Daily)'\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"\n",
"I apologize for the error. It seems that the Alpha Vantage API is not returning the expected data. This might be due to the API call limit or an issue with the API itself. To resolve this, let's add a delay between API calls and handle the KeyError exception:\n",
"\n",
"```python\n",
"import requests\n",
"import json\n",
"import time\n",
"\n",
"api_key = \"YOUR_API_KEY\"\n",
"symbols = [\"AAPL\", \"MSFT\", \"GOOGL\", \"GOOG\", \"AMZN\", \"FB\", \"TSLA\", \"NVDA\", \"ADBE\", \"ASML\"]\n",
"\n",
"def get_ytd_gain(symbol):\n",
" url = f\"https://www.alphavantage.co/query?function=TIME_SERIES_DAILY_ADJUSTED&symbol={symbol}&apikey={api_key}\"\n",
" response = requests.get(url)\n",
" data = json.loads(response.text)\n",
" if \"Time Series (Daily)\" not in data:\n",
" return None\n",
" daily_data = data[\"Time Series (Daily)\"]\n",
" sorted_dates = sorted(daily_data.keys())\n",
" first_day_of_year = None\n",
" for date in sorted_dates:\n",
" if date.startswith(\"2022-01\"):\n",
" first_day_of_year = date\n",
" break\n",
" if first_day_of_year is None:\n",
" return None\n",
" last_day = sorted_dates[-1]\n",
" open_price = float(daily_data[first_day_of_year][\"1. open\"])\n",
" close_price = float(daily_data[last_day][\"4. close\"])\n",
" ytd_gain = (close_price - open_price) / open_price * 100\n",
" return ytd_gain\n",
"\n",
"ytd_gains = {}\n",
"for symbol in symbols:\n",
" ytd_gains[symbol] = get_ytd_gain(symbol)\n",
" time.sleep(15) # Add a delay between API calls\n",
"\n",
"print(\"YTD Gains of 10 Largest Technology Companies:\")\n",
"for symbol, gain in ytd_gains.items():\n",
" if gain is not None:\n",
" print(f\"{symbol}: {gain:.2f}%\")\n",
" else:\n",
" print(f\"{symbol}: Data not available\")\n",
"```\n",
"\n",
"Please replace `YOUR_API_KEY` with your actual API key and execute the code again. This should resolve the issue and display the YTD gains for the 10 largest technology companies. Note that the script will take some time to execute due to the added delay between API calls.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"YTD Gains of 10 Largest Technology Companies:\n",
"AAPL: Data not available\n",
"MSFT: Data not available\n",
"GOOGL: Data not available\n",
"GOOG: Data not available\n",
"AMZN: Data not available\n",
"FB: Data not available\n",
"TSLA: Data not available\n",
"NVDA: Data not available\n",
"ADBE: Data not available\n",
"ASML: Data not available\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"\n",
"I apologize for the inconvenience. It seems that the Alpha Vantage API is not providing the required data. In this case, let's switch to another API, the FRED (Federal Reserve Economic Data) API, which provides historical stock data. First, you need to obtain an API key from the FRED website. Please visit https://research.stlouisfed.org/docs/api/api_key.html and get your API key.\n",
"\n",
"Once you have the API key, please replace `YOUR_API_KEY` in the code below with your actual API key and execute the code:\n",
"\n",
"```python\n",
"import requests\n",
"import json\n",
"import time\n",
"\n",
"api_key = \"YOUR_API_KEY\"\n",
"symbols = [\"AAPL\", \"MSFT\", \"GOOGL\", \"GOOG\", \"AMZN\", \"FB\", \"TSLA\", \"NVDA\", \"ADBE\", \"ASML\"]\n",
"\n",
"def get_ytd_gain(symbol):\n",
" url = f\"https://fred.stlouisfed.org/graph/fredgraph.csv?bgcolor=%23e1e9f0&chart_type=line&drp=0&fo=open%20sans&graph_bgcolor=%23ffffff&height=450&mode=fred&recession_bars=on&txtcolor=%23444444&ts=12&tts=12&width=1168&nt=0&thu=0&trc=0&show_legend=yes&show_axis_titles=yes&show_tooltip=yes&id={symbol}&scale=left&cosd=2022-01-01&coed=2022-12-31&line_color=%234572a7&link_values=false&line_style=solid&mark_type=none&mw=3&lw=2&ost=-99999&oet=99999&mma=0&fml=a&fq=Daily&fam=avg&fgst=lin&fgsnd=2020-02-01&line_index=1&transformation=lin&vintage_date=2022-12-31&revision_date=2022-12-31&nd=2022-01-01\"\n",
" response = requests.get(url)\n",
" data = response.text.splitlines()\n",
" if len(data) < 2:\n",
" return None\n",
" first_day_of_year = data[1].split(\",\")[1]\n",
" last_day = data[-1].split(\",\")[1]\n",
" open_price = float(first_day_of_year)\n",
" close_price = float(last_day)\n",
" ytd_gain = (close_price - open_price) / open_price * 100\n",
" return ytd_gain\n",
"\n",
"ytd_gains = {}\n",
"for symbol in symbols:\n",
" ytd_gains[symbol] = get_ytd_gain(symbol)\n",
" time.sleep(5) # Add a delay between API calls\n",
"\n",
"print(\"YTD Gains of 10 Largest Technology Companies:\")\n",
"for symbol, gain in ytd_gains.items():\n",
" if gain is not None:\n",
" print(f\"{symbol}: {gain:.2f}%\")\n",
" else:\n",
" print(f\"{symbol}: Data not available\")\n",
"```\n",
"\n",
"This script fetches the daily stock data for the 10 largest technology companies from the FRED API and calculates the YTD gain for each of them. The results will be printed in the format \"Symbol: YTD Gain%\". Note that the script will take some time to execute due to the added delay between API calls.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_ba93be30359073e751a640cd13b3ef7e.py\", line 23, in <module>\n",
" ytd_gains[symbol] = get_ytd_gain(symbol)\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_ba93be30359073e751a640cd13b3ef7e.py\", line 14, in get_ytd_gain\n",
" first_day_of_year = data[1].split(\",\")[1]\n",
"IndexError: list index out of range\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"\n",
"I apologize for the confusion. It seems that the FRED API does not provide the required stock data. In this case, let's use the Yahoo Finance API, which is a popular choice for fetching stock data. We will use the `yfinance` library to fetch the data. First, you need to install the library by executing the following command:\n",
"I apologize for the error. It seems that the website structure has changed, and the script is unable to find the table containing the top technology companies. Let's try an alternative approach using the Yahoo Finance library. Please install the library first by executing:\n",
"\n",
"```sh\n",
"pip install yfinance\n",
"```\n",
"\n",
"Once the library is installed, please execute the following code:\n",
"After installing the library, please execute the following Python code:\n",
"\n",
"```python\n",
"import yfinance as yf\n",
"import datetime\n",
"\n",
"symbols = [\"AAPL\", \"MSFT\", \"GOOGL\", \"GOOG\", \"AMZN\", \"FB\", \"TSLA\", \"NVDA\", \"ADBE\", \"ASML\"]\n",
"top_tech_companies = [\n",
" \"AAPL\", \"MSFT\", \"GOOGL\", \"GOOG\", \"AMZN\", \"FB\", \"TSLA\", \"NVDA\", \"ASML\", \"ADBE\"\n",
"]\n",
"\n",
"def get_ytd_gain(symbol):\n",
" now = datetime.datetime.now()\n",
" start_date = datetime.datetime(now.year, 1, 1)\n",
" end_date = now\n",
" stock_data = yf.download(symbol, start=start_date, end=end_date)\n",
" if stock_data.empty:\n",
" return None\n",
" open_price = stock_data.iloc[0][\"Open\"]\n",
" close_price = stock_data.iloc[-1][\"Close\"]\n",
" ytd_gain = (close_price - open_price) / open_price * 100\n",
" today = datetime.date.today()\n",
" start_date = datetime.date(today.year, 1, 1)\n",
" stock = yf.Ticker(symbol)\n",
" data = stock.history(start=start_date, end=today)\n",
" ytd_gain = (data[\"Close\"].iloc[-1] - data[\"Close\"].iloc[0]) / data[\"Close\"].iloc[0] * 100\n",
" return ytd_gain\n",
"\n",
"ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
"print(\"YTD Gains of 10 Largest Technology Companies:\")\n",
"ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in top_tech_companies}\n",
"\n",
"for symbol, gain in ytd_gains.items():\n",
" print(f\"{symbol}: {gain:.2f}%\")\n",
"```\n",
"\n",
"This code uses a predefined list of the 10 largest technology companies and calculates their YTD gains using the Yahoo Finance library.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"user_proxy (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: yfinance in /home/vscode/.local/lib/python3.9/site-packages (0.2.26)\n",
"Requirement already satisfied: lxml>=4.9.1 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (4.9.3)\n",
"Requirement already satisfied: pytz>=2022.5 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (2023.3)\n",
"Requirement already satisfied: numpy>=1.16.5 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (1.25.1)\n",
"Requirement already satisfied: html5lib>=1.1 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (1.1)\n",
"Requirement already satisfied: beautifulsoup4>=4.11.1 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (4.12.2)\n",
"Requirement already satisfied: multitasking>=0.0.7 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (0.0.11)\n",
"Requirement already satisfied: appdirs>=1.4.4 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (1.4.4)\n",
"Requirement already satisfied: pandas>=1.3.0 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (2.0.3)\n",
"Requirement already satisfied: requests>=2.31 in /usr/local/lib/python3.9/site-packages (from yfinance) (2.31.0)\n",
"Requirement already satisfied: frozendict>=2.3.4 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (2.3.8)\n",
"Requirement already satisfied: soupsieve>1.2 in /home/vscode/.local/lib/python3.9/site-packages (from beautifulsoup4>=4.11.1->yfinance) (2.4.1)\n",
"Requirement already satisfied: six>=1.9 in /usr/local/lib/python3.9/site-packages (from html5lib>=1.1->yfinance) (1.16.0)\n",
"Requirement already satisfied: webencodings in /home/vscode/.local/lib/python3.9/site-packages (from html5lib>=1.1->yfinance) (0.5.1)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/vscode/.local/lib/python3.9/site-packages (from pandas>=1.3.0->yfinance) (2.8.2)\n",
"Requirement already satisfied: tzdata>=2022.1 in /home/vscode/.local/lib/python3.9/site-packages (from pandas>=1.3.0->yfinance) (2023.3)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.9/site-packages (from requests>=2.31->yfinance) (2.0.3)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/site-packages (from requests>=2.31->yfinance) (2023.5.7)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/site-packages (from requests>=2.31->yfinance) (3.4)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.9/site-packages (from requests>=2.31->yfinance) (3.2.0)\n",
"\n",
"FB: No timezone found, symbol may be delisted\n",
"Traceback (most recent call last):\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_52c69e134e808cf8ba196ac60faede00.py\", line 16, in <module>\n",
" ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in top_tech_companies}\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_52c69e134e808cf8ba196ac60faede00.py\", line 16, in <dictcomp>\n",
" ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in top_tech_companies}\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_52c69e134e808cf8ba196ac60faede00.py\", line 13, in get_ytd_gain\n",
" ytd_gain = (data[\"Close\"].iloc[-1] - data[\"Close\"].iloc[0]) / data[\"Close\"].iloc[0] * 100\n",
" File \"/home/vscode/.local/lib/python3.9/site-packages/pandas/core/indexing.py\", line 1103, in __getitem__\n",
" return self._getitem_axis(maybe_callable, axis=axis)\n",
" File \"/home/vscode/.local/lib/python3.9/site-packages/pandas/core/indexing.py\", line 1656, in _getitem_axis\n",
" self._validate_integer(key, axis)\n",
" File \"/home/vscode/.local/lib/python3.9/site-packages/pandas/core/indexing.py\", line 1589, in _validate_integer\n",
" raise IndexError(\"single positional indexer is out-of-bounds\")\n",
"IndexError: single positional indexer is out-of-bounds\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user_proxy):\n",
"\n",
"I apologize for the error. It seems that there is an issue with the data for the \"FB\" symbol. Let's add a try-except block to handle this issue. Please execute the following Python code:\n",
"\n",
"```python\n",
"import yfinance as yf\n",
"import datetime\n",
"\n",
"top_tech_companies = [\n",
" \"AAPL\", \"MSFT\", \"GOOGL\", \"GOOG\", \"AMZN\", \"FB\", \"TSLA\", \"NVDA\", \"ASML\", \"ADBE\"\n",
"]\n",
"\n",
"def get_ytd_gain(symbol):\n",
" try:\n",
" today = datetime.date.today()\n",
" start_date = datetime.date(today.year, 1, 1)\n",
" stock = yf.Ticker(symbol)\n",
" data = stock.history(start=start_date, end=today)\n",
" ytd_gain = (data[\"Close\"].iloc[-1] - data[\"Close\"].iloc[0]) / data[\"Close\"].iloc[0] * 100\n",
" return ytd_gain\n",
" except Exception as e:\n",
" print(f\"Error fetching data for {symbol}: {e}\")\n",
" return None\n",
"\n",
"ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in top_tech_companies}\n",
"\n",
"for symbol, gain in ytd_gains.items():\n",
" if gain is not None:\n",
" print(f\"{symbol}: {gain:.2f}%\")\n",
@ -542,149 +439,106 @@
" print(f\"{symbol}: Data not available\")\n",
"```\n",
"\n",
"This script fetches the daily stock data for the 10 largest technology companies from the Yahoo Finance API using the `yfinance` library and calculates the YTD gain for each of them. The results will be printed in the format \"Symbol: YTD Gain%\".\n",
"This code will handle any issues with the data and display \"Data not available\" for any symbols that encounter an error.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to assistant):\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"user_proxy (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: yfinance in /home/vscode/.local/lib/python3.9/site-packages (0.2.18)\n",
"Requirement already satisfied: pandas>=1.3.0 in /usr/local/lib/python3.9/site-packages (from yfinance) (1.5.2)\n",
"Requirement already satisfied: numpy>=1.16.5 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (1.23.5)\n",
"Requirement already satisfied: requests>=2.26 in /usr/local/lib/python3.9/site-packages (from yfinance) (2.28.1)\n",
"Requirement already satisfied: multitasking>=0.0.7 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (0.0.11)\n",
"Requirement already satisfied: lxml>=4.9.1 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (4.9.2)\n",
"Requirement already satisfied: appdirs>=1.4.4 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (1.4.4)\n",
"Requirement already satisfied: pytz>=2022.5 in /usr/local/lib/python3.9/site-packages (from yfinance) (2022.6)\n",
"Requirement already satisfied: frozendict>=2.3.4 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (2.3.8)\n",
"Requirement already satisfied: cryptography>=3.3.2 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (38.0.4)\n",
"Requirement already satisfied: beautifulsoup4>=4.11.1 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (4.11.1)\n",
"Requirement already satisfied: html5lib>=1.1 in /home/vscode/.local/lib/python3.9/site-packages (from yfinance) (1.1)\n",
"Requirement already satisfied: soupsieve>1.2 in /home/vscode/.local/lib/python3.9/site-packages (from beautifulsoup4>=4.11.1->yfinance) (2.3.2.post1)\n",
"Requirement already satisfied: cffi>=1.12 in /home/vscode/.local/lib/python3.9/site-packages (from cryptography>=3.3.2->yfinance) (1.15.1)\n",
"Requirement already satisfied: six>=1.9 in /usr/local/lib/python3.9/site-packages (from html5lib>=1.1->yfinance) (1.16.0)\n",
"Requirement already satisfied: webencodings in /home/vscode/.local/lib/python3.9/site-packages (from html5lib>=1.1->yfinance) (0.5.1)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/site-packages (from pandas>=1.3.0->yfinance) (2.8.2)\n",
"Requirement already satisfied: charset-normalizer<3,>=2 in /usr/local/lib/python3.9/site-packages (from requests>=2.26->yfinance) (2.1.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/site-packages (from requests>=2.26->yfinance) (3.4)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/site-packages (from requests>=2.26->yfinance) (1.26.13)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.9/site-packages (from requests>=2.26->yfinance) (2022.9.24)\n",
"Requirement already satisfied: pycparser in /home/vscode/.local/lib/python3.9/site-packages (from cffi>=1.12->cryptography>=3.3.2->yfinance) (2.21)\n",
"\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"\n",
"1 Failed download:\n",
"- FB: No timezone found, symbol may be delisted\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"YTD Gains of 10 Largest Technology Companies:\n",
"AAPL: 44.37%\n",
"MSFT: 36.77%\n",
"GOOGL: 30.75%\n",
"GOOG: 31.04%\n",
"AMZN: 50.69%\n",
"Error fetching data for FB: single positional indexer is out-of-bounds\n",
"AAPL: 55.26%\n",
"MSFT: 47.19%\n",
"GOOGL: 37.13%\n",
"GOOG: 36.89%\n",
"AMZN: 50.47%\n",
"FB: Data not available\n",
"TSLA: 127.73%\n",
"NVDA: 185.54%\n",
"ADBE: 48.38%\n",
"ASML: 25.61%\n",
"TSLA: 145.40%\n",
"NVDA: 219.19%\n",
"ASML: 25.70%\n",
"ADBE: 55.92%\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"assistant (to user_proxy):\n",
"\n",
"Great! The code executed successfully, and the YTD gains of the 10 largest technology companies are as follows:\n",
"I'm glad the code executed successfully. Here are the YTD gains for the 10 largest technology companies as of today:\n",
"\n",
"- AAPL: 44.37%\n",
"- MSFT: 36.77%\n",
"- GOOGL: 30.75%\n",
"- GOOG: 31.04%\n",
"- AMZN: 50.69%\n",
"- FB: Data not available (The symbol may be delisted or not available in the Yahoo Finance API)\n",
"- TSLA: 127.73%\n",
"- NVDA: 185.54%\n",
"- ADBE: 48.38%\n",
"- ASML: 25.61%\n",
"- AAPL: 55.26%\n",
"- MSFT: 47.19%\n",
"- GOOGL: 37.13%\n",
"- GOOG: 36.89%\n",
"- AMZN: 50.47%\n",
"- FB: Data not available\n",
"- TSLA: 145.40%\n",
"- NVDA: 219.19%\n",
"- ASML: 25.70%\n",
"- ADBE: 55.92%\n",
"\n",
"Please note that these values are subject to change as the stock market fluctuates. If you need to update the values, simply re-run the code.\n",
"Please note that the data for \"FB\" is not available due to an issue with the data source. You may want to check an alternative data source or try again later for the \"FB\" symbol.\n",
"\n",
"If you have any other questions or need further assistance, feel free to ask. Otherwise, if everything is done, please type \"TERMINATE\".\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"user_proxy (to assistant):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user_proxy):\n",
"\n",
"It seems you didn't provide any input. If you have any other questions or need further assistance, feel free to ask. Otherwise, if everything is done, please type \"TERMINATE\".\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"user_proxy (to assistant):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user_proxy):\n",
"\n",
"I understand that you might be busy or have no further questions. If you need any assistance in the future, feel free to ask. Have a great day!\n",
"\n",
"TERMINATE\n",
"\n",
"--------------------------------------------------------------------------------\n",
"user (to assistant):\n",
"user_proxy (to assistant):\n",
"\n",
"continue\n",
"FB -> META\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"assistant (to user_proxy):\n",
"\n",
"I apologize for the confusion regarding the FB symbol. It seems that the symbol for Meta Platforms, Inc. (formerly Facebook, Inc.) has changed to \"META\" on the Yahoo Finance API. Let's update the symbol and re-run the code to get the YTD gain for Meta Platforms, Inc.\n",
"\n",
"```python\n",
"symbols = [\"AAPL\", \"MSFT\", \"GOOGL\", \"GOOG\", \"AMZN\", \"META\", \"TSLA\", \"NVDA\", \"ADBE\", \"ASML\"]\n",
"\n",
"ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
"print(\"YTD Gains of 10 Largest Technology Companies:\")\n",
"for symbol, gain in ytd_gains.items():\n",
" if gain is not None:\n",
" print(f\"{symbol}: {gain:.2f}%\")\n",
" else:\n",
" print(f\"{symbol}: Data not available\")\n",
"```\n",
"\n",
"Please execute the code above to get the updated YTD gains for the 10 largest technology companies, including Meta Platforms, Inc. (formerly Facebook, Inc.).\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to assistant):\n",
"\n",
"exitcode: 1 (execution failed)\n",
"Code output: \n",
"Traceback (most recent call last):\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_cfc7665767b48828cfef36e1a2a3dbc6.py\", line 3, in <module>\n",
" ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
" File \"/workspaces/FLAML/notebook/web/tmp_code_cfc7665767b48828cfef36e1a2a3dbc6.py\", line 3, in <dictcomp>\n",
" ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
"NameError: name 'get_ytd_gain' is not defined\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"\n",
"I apologize for the error. It seems I didn't include the `get_ytd_gain` function in the code snippet. Here's the complete code with the updated symbol for Meta Platforms, Inc. (formerly Facebook, Inc.):\n",
"I apologize for the confusion earlier. The \"FB\" symbol has been changed to \"META\" due to the company's rebranding. Let's update the code with the new symbol and execute it again. Please run the following Python code:\n",
"\n",
"```python\n",
"import yfinance as yf\n",
"import datetime\n",
"\n",
"symbols = [\"AAPL\", \"MSFT\", \"GOOGL\", \"GOOG\", \"AMZN\", \"META\", \"TSLA\", \"NVDA\", \"ADBE\", \"ASML\"]\n",
"top_tech_companies = [\n",
" \"AAPL\", \"MSFT\", \"GOOGL\", \"GOOG\", \"AMZN\", \"META\", \"TSLA\", \"NVDA\", \"ASML\", \"ADBE\"\n",
"]\n",
"\n",
"def get_ytd_gain(symbol):\n",
" now = datetime.datetime.now()\n",
" start_date = datetime.datetime(now.year, 1, 1)\n",
" end_date = now\n",
" stock_data = yf.download(symbol, start=start_date, end=end_date)\n",
" if stock_data.empty:\n",
" try:\n",
" today = datetime.date.today()\n",
" start_date = datetime.date(today.year, 1, 1)\n",
" stock = yf.Ticker(symbol)\n",
" data = stock.history(start=start_date, end=today)\n",
" ytd_gain = (data[\"Close\"].iloc[-1] - data[\"Close\"].iloc[0]) / data[\"Close\"].iloc[0] * 100\n",
" return ytd_gain\n",
" except Exception as e:\n",
" print(f\"Error fetching data for {symbol}: {e}\")\n",
" return None\n",
" open_price = stock_data.iloc[0][\"Open\"]\n",
" close_price = stock_data.iloc[-1][\"Close\"]\n",
" ytd_gain = (close_price - open_price) / open_price * 100\n",
" return ytd_gain\n",
"\n",
"ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in symbols}\n",
"print(\"YTD Gains of 10 Largest Technology Companies:\")\n",
"ytd_gains = {symbol: get_ytd_gain(symbol) for symbol in top_tech_companies}\n",
"\n",
"for symbol, gain in ytd_gains.items():\n",
" if gain is not None:\n",
" print(f\"{symbol}: {gain:.2f}%\")\n",
@ -692,55 +546,68 @@
" print(f\"{symbol}: Data not available\")\n",
"```\n",
"\n",
"Please execute the code above to get the updated YTD gains for the 10 largest technology companies, including Meta Platforms, Inc. (formerly Facebook, Inc.).\n",
"This code will now use the \"META\" symbol instead of \"FB\" and should provide the YTD gain for Meta Platforms Inc. (formerly Facebook).\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> NO HUMAN INPUT RECEIVED. USING AUTO REPLY FOR THE USER...\n",
"user (to assistant):\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"user_proxy (to assistant):\n",
"\n",
"exitcode: 0 (execution succeeded)\n",
"Code output: \n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"[*********************100%***********************] 1 of 1 completed\n",
"YTD Gains of 10 Largest Technology Companies:\n",
"AAPL: 44.37%\n",
"MSFT: 36.77%\n",
"GOOGL: 30.75%\n",
"GOOG: 31.04%\n",
"AMZN: 50.69%\n",
"META: 142.87%\n",
"TSLA: 127.73%\n",
"NVDA: 185.54%\n",
"ADBE: 48.38%\n",
"ASML: 25.61%\n",
"AAPL: 55.26%\n",
"MSFT: 47.19%\n",
"GOOGL: 37.13%\n",
"GOOG: 36.89%\n",
"AMZN: 50.47%\n",
"META: 136.07%\n",
"TSLA: 145.40%\n",
"NVDA: 219.19%\n",
"ASML: 25.70%\n",
"ADBE: 55.92%\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user):\n",
"assistant (to user_proxy):\n",
"\n",
"Great! The code executed successfully, and the updated YTD gains of the 10 largest technology companies are as follows:\n",
"Great! The code executed successfully, and here are the updated YTD gains for the 10 largest technology companies as of today:\n",
"\n",
"- AAPL: 44.37%\n",
"- MSFT: 36.77%\n",
"- GOOGL: 30.75%\n",
"- GOOG: 31.04%\n",
"- AMZN: 50.69%\n",
"- META (formerly FB): 142.87%\n",
"- TSLA: 127.73%\n",
"- NVDA: 185.54%\n",
"- ADBE: 48.38%\n",
"- ASML: 25.61%\n",
"- AAPL: 55.26%\n",
"- MSFT: 47.19%\n",
"- GOOGL: 37.13%\n",
"- GOOG: 36.89%\n",
"- AMZN: 50.47%\n",
"- META: 136.07% (formerly FB)\n",
"- TSLA: 145.40%\n",
"- NVDA: 219.19%\n",
"- ASML: 25.70%\n",
"- ADBE: 55.92%\n",
"\n",
"Please note that these values are subject to change as the stock market fluctuates. If you need to update the values, simply re-run the code.\n",
"If you have any other questions or need further assistance, feel free to ask. Otherwise, if everything is done, please type \"TERMINATE\".\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"user_proxy (to assistant):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user_proxy):\n",
"\n",
"It seems you didn't provide any input. If you have any other questions or need further assistance, feel free to ask. Otherwise, if everything is done, please type \"TERMINATE\".\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\n",
">>>>>>>> USING AUTO REPLY FOR THE USER...\n",
"user_proxy (to assistant):\n",
"\n",
"\n",
"\n",
"--------------------------------------------------------------------------------\n",
"assistant (to user_proxy):\n",
"\n",
"I understand that you might be busy or have no further questions. If you need any assistance in the future, feel free to ask. Have a great day!\n",
"\n",
"TERMINATE\n",
"\n",
@ -751,7 +618,7 @@
"source": [
"# it is suggested to reset the assistant to clear the state if the new task is not related to the previous one.\n",
"assistant.reset()\n",
"user.initiate_chat(\n",
"user_proxy.initiate_chat(\n",
" assistant,\n",
" message=\"\"\"Show me the YTD gain of 10 largest technology companies as of today.\"\"\",\n",
")"
@ -774,7 +641,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.9.17"
},
"vscode": {
"interpreter": {

View File

@ -72,19 +72,19 @@ def test_filter():
def test_chatcompletion():
params = oai.ChatCompletion._construct_params(
data_instance=None,
context=None,
config={"model": "unknown"},
prompt="hi",
)
assert "messages" in params
params = oai.Completion._construct_params(
data_instance=None,
context=None,
config={"model": "unknown"},
prompt="hi",
)
assert "messages" not in params
params = oai.Completion._construct_params(
data_instance=None,
context=None,
config={"model": "gpt-4"},
prompt="hi",
)

View File

@ -2,7 +2,7 @@ import os
import sys
import pytest
from flaml import oai
from flaml.autogen.agent import AssistantAgent, UserProxyAgent
from flaml.autogen.agentchat import AssistantAgent, UserProxyAgent
KEY_LOC = "test/autogen"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
@ -97,7 +97,7 @@ def test_gpt35(human_input_mode="NEVER", max_consecutive_auto_reply=5):
)
user.initiate_chat(assistant, message="TERMINATE")
# should terminate without sending any message
assert assistant.oai_conversations[user.name][-1]["content"] == "TERMINATE"
assert assistant.last_message()["content"] == assistant.last_message(user)["content"] == "TERMINATE"
assistant.reset()
coding_task = "Print hello world to a file called hello.txt"
user.initiate_chat(assistant, message=coding_task)

View File

@ -75,8 +75,14 @@ def test_execute_code():
# execute code which takes a long time
exit_code, error, image = execute_code("import time; time.sleep(2)", timeout=1)
assert exit_code and error.decode() == "Timeout"
assert isinstance(image, str)
def test_execute_code_no_docker():
exit_code, error, image = execute_code("import time; time.sleep(2)", timeout=1, use_docker=False)
assert exit_code and error.decode() == "Timeout" and image is None
if sys.platform != "win32":
assert exit_code and error.decode() == "Timeout"
assert image is None
if __name__ == "__main__":

View File

@ -62,7 +62,7 @@ def test_eval_math_responses():
def test_json_extraction():
from flaml.autogen.agent import UserProxyAgent
from flaml.autogen.agentchat import UserProxyAgent
user = UserProxyAgent(name="test", code_execution_config={"use_docker": False})
@ -77,7 +77,7 @@ def test_json_extraction():
def test_execute_function():
from flaml.autogen.agent import UserProxyAgent
from flaml.autogen.agentchat import UserProxyAgent
# 1. test calling a simple function
def add_num(num_to_be_added):

View File

@ -1,5 +1,9 @@
from flaml import oai
from flaml.autogen.agent.math_user_proxy_agent import MathUserProxyAgent, _remove_print, _add_print_to_last_line
from flaml.autogen.agentchat.contrib.math_user_proxy_agent import (
MathUserProxyAgent,
_remove_print,
_add_print_to_last_line,
)
import pytest
import sys
@ -17,7 +21,7 @@ def test_math_user_proxy_agent():
except ImportError:
return
from flaml.autogen.agent.assistant_agent import AssistantAgent
from flaml.autogen.agentchat.assistant_agent import AssistantAgent
conversations = {}
oai.ChatCompletion.start_logging(conversations)

View File

@ -48,8 +48,8 @@ def run_notebook(input_nb, output_nb="executed_openai_notebook.ipynb", save=Fals
skip or not sys.version.startswith("3.10"),
reason="do not run if openai is not installed or py!=3.10",
)
def test_autogen_agent_auto_feedback_from_code(save=False):
run_notebook("autogen_agent_auto_feedback_from_code_execution.ipynb", save=save)
def test_autogen_agentchat_auto_feedback_from_code(save=False):
run_notebook("autogen_agentchat_auto_feedback_from_code_execution.ipynb", save=save)
@pytest.mark.skipif(
@ -64,16 +64,16 @@ def test_autogen_openai_completion(save=False):
skip or not sys.version.startswith("3.10"),
reason="do not run if openai is not installed or py!=3.10",
)
def test_autogen_agent_function_call(save=False):
run_notebook("autogen_agent_function_call.ipynb", save=save)
def test_autogen_agentchat_function_call(save=False):
run_notebook("autogen_agentchat_function_call.ipynb", save=save)
@pytest.mark.skipif(
skip or not sys.version.startswith("3.10"),
reason="do not run if openai is not installed or py!=3.10",
)
def test_autogen_agent_MathChat(save=False):
run_notebook("autogen_agent_MathChat.ipynb", save=save)
def test_autogen_agentchat_MathChat(save=False):
run_notebook("autogen_agentchat_MathChat.ipynb", save=save)
@pytest.mark.skipif(
@ -85,8 +85,8 @@ def test_autogen_chatgpt_gpt4(save=False):
if __name__ == "__main__":
test_autogen_agent_auto_feedback_from_code(save=True)
test_autogen_agentchat_auto_feedback_from_code(save=True)
# test_autogen_chatgpt_gpt4(save=True)
# test_autogen_openai_completion(save=True)
# test_autogen_agent_MathChat(save=True)
# test_autogen_agent_function_call(save=True)
# test_autogen_agentchat_MathChat(save=True)
# test_autogen_agentchat_function_call(save=True)

View File

@ -1,7 +1,7 @@
import sys
from io import StringIO
import pytest
from flaml.autogen.agent import ResponsiveAgent
from flaml.autogen.agentchat import ResponsiveAgent
def test_responsive_agent(monkeypatch):
@ -35,9 +35,10 @@ def test_responsive_agent(monkeypatch):
dummy_agent_2,
) # send a dict
# receive dict with no openai fields
# send dict with no openai fields
pre_len = len(dummy_agent_1.oai_conversations["dummy_agent_2"])
dummy_agent_1.send({"message": "hello"}, dummy_agent_2) # send dict with wrong field
with pytest.raises(ValueError):
dummy_agent_1.send({"message": "hello"}, dummy_agent_2)
assert pre_len == len(
dummy_agent_1.oai_conversations["dummy_agent_2"]

View File

@ -44,8 +44,8 @@ response = oai.Completion.create(context=test_instance, **config)
- LLM-driven intelligent agents which can perform tasks autonomously or with human feedback, including tasks that require using tools via code. For example,
```python
assistant = AssistantAgent("assistant")
user = UserProxyAgent("user", human_input_mode="TERMINATE")
assistant.receive("Draw a rocket and save to a file named 'rocket.svg'")
user_proxy = UserProxyAgent("user_proxy")
user_proxy.initiate_chat("Show me the YTD gain of 10 largest technology companies as of today.")
```
#### [Task-oriented AutoML](/docs/Use-Cases/task-oriented-automl)

View File

@ -7,13 +7,13 @@
* An enhanced inference API as a drop-in replacement of `openai.Completion.create` or `openai.ChatCompletion.create`. It allows easy performance tuning and advanced usage patterns, including:
- Leveraging [`flaml.tune`](/docs/reference/tune/tune) to adapt LLMs to applications, to maximize the utility out of using expensive foundation models and reduce the inference cost by using cheaper models or configurations which achieve equal or better performance.
- Utilities like API unification, caching, error handling, multi-config inference, context programming etc.
* A higher-level abstraction of using foundation models: intelligent agents which can perform tasks autonomously or with human feedback. The same abstraction allows both automated feedback and human feedback sent between agents, so that complex tasks can be accomplished, including tasks that require using tools via code.
* A higher-level abstraction of using foundation models: intelligent agents which can perform tasks autonomously or with human feedback. The same abstraction allows both automated feedback and human feedback sent between agents, so that complex tasks can be accomplished via agent collaborations, including tasks that require using tools via code.
The package is under active development with more features upcoming.
## Agents (Experimental)
[`flaml.autogen.agent`](/docs/reference/autogen/agent/agent) contains an experimental implementation of interactive agents which can adapt to human or simulated feedback. This subpackage is under active development.
[`flaml.autogen.agentchat`](/docs/reference/autogen/agentchat/agent) contains an experimental implementation of interactive agents which can adapt to human or simulated feedback. This subpackage is under active development.
We have designed different classes of Agents that are capable of communicating with each other through the exchange of messages to collaboratively finish a task. An agent can communicate with other agents and perform actions. Different agents can differ in what actions they perform after receiving messages.
@ -57,7 +57,7 @@ In the example above, we create an AssistantAgent named "assistant" to serve as
4. The assistant then generates a further response for the user_proxy. The user_proxy can then decide whether to terminate the conversation. If not, steps 3 and 4 are repeated.
Please find a visual illustration of how UserProxyAgent and AssistantAgent collaboratively solve the above task below:
![Agent Example](images/agent_example.png)
![Agent Chat Example](images/agent_example.png)
#### Human Input Mode
The `human_input_mode` parameter of `UserProxyAgent` controls the behavior of the agent when it receives a message. It can be set to `"NEVER"`, `"ALWAYS"`, or `"TERMINATE"`.
@ -125,19 +125,19 @@ user.initiate_chat(
```
*Interested in trying it yourself? Please check the following notebook examples:*
* [Interactive LLM Agent with Auto Feedback from Code Execution](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agent_auto_feedback_from_code_execution.ipynb)
* [Automated Task Solving with Code Generation, Execution & Debugging](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agentchat_auto_feedback_from_code_execution.ipynb)
* [Interactive LLM Agent with Human Feedback](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agent_human_feedback.ipynb)
* [Task Solving with Auto Code Generation, Execution, Debugging and Human Feedback](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agentchat_human_feedback.ipynb)
* [Interactive LLM Agent Dealing with Web Info](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agent_web_info.ipynb)
* [Solve Tasks Requiring Web Info](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agentchat_web_info.ipynb)
* [Using MathChat to Solve Math Problems](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agent_MathChat.ipynb)
* [Using MathChat to Solve Math Problems](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agentchat_MathChat.ipynb)
* [Interactive LLM Agent with Function Calls](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agent_function_call.ipynb)
* [Use Provided Tools as Functions](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agentchat_function_call.ipynb)
* [Multi-Agent Communication and Planning](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agent_planning.ipynb)
* [Automated Task Solving with Coding & Planning Agents](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agentchat_planning.ipynb)
* [Multi-Agent Multi-User Application](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agent_two_users.ipynb)
* [Automated Task Solving with Multiple Human Users](https://github.com/microsoft/FLAML/blob/main/notebook/autogen_agentchat_two_users.ipynb)
## Enhanced Inference