mirror of https://github.com/microsoft/autogen.git
Adding a test function for OpenAI completion in flaml (#951)
* improve max_valid_n and doc
* Update README.md Co-authored-by: Li Jiang <lijiang1@microsoft.com>
* add support for chatgpt
* notebook
* newline at end of file
* chatgpt notebook
* ChatGPT in Azure
* doc
* math
* warning, timeout, log file name
* handle import error
* doc update; default value
* paper
* doc
* docstr
* eval_func
* add a test func in completion
* update notebook
* update math notebook
* improve notebook
* lint and handle exception
* flake8
* exception in test
* add agg_method
* NameError
* refactor
* Update flaml/integrations/oai/completion.py Co-authored-by: Chi Wang <wang.chi@microsoft.com>
* Update flaml/integrations/oai/completion.py Co-authored-by: Chi Wang <wang.chi@microsoft.com>
* add example
* merge files from oai_eval_test
* Revert "merge files from oai_eval_test" This reverts commit 1e6a550f913bb94df6e9680934ccb7175d00702e.
* merge
* save results to notebook_output
* update version and cache
* update doc
* save nb cell results to file
* fix typo in model name
* code improvements
* improve docstr
* docstr
* docstr on the Returns of test

---------

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Li Jiang <lijiang1@microsoft.com>
Co-authored-by: Susan Xueqing Liu <liususan091219@users.noreply.github.com>
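Editor's note: the headline change in this commit is the new `Completion.test` API. The sketch below is a minimal illustration of how it is meant to be called, pieced together from the signature and docstring further down in this diff; the dataset, config, prompt template, and metric function are placeholders invented for illustration, not code from the commit.

```python
from flaml import oai

# Placeholder data and config; any list of dicts and any eval_func that
# returns a dict of numeric metrics should work per the docstring below.
test_data = [{"problem": "What is 1 + 1?", "solution": "2"}]
config = {"model": "gpt-3.5-turbo", "prompt": "{problem}", "max_tokens": 32, "n": 1}

def eval_func(responses, solution, **kwargs):
    # Return a dict of numeric metrics for a single data instance.
    return {"success": any(solution in r for r in responses)}

# Aggregated metrics over all instances; agg_method defaults to "avg".
result = oai.Completion.test(test_data, config, eval_func)
print(result)
```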
This commit is contained in:
parent 05c5f8f426
commit 45641000c0
@@ -40,6 +40,7 @@ jobs:
pip install coverage pytest datasets nbconvert nbformat ipykernel
coverage run -a -m pytest test/openai
coverage xml
cat "$(pwd)/test/openai/executed_openai_notebook_output.txt"
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:

@@ -118,16 +118,17 @@ class Completion:
cls.cache_path = f"{cache_path}/{seed}"

@classmethod
def _get_response(cls, config: dict, eval_only=False):
def _get_response(cls, config: dict, eval_only=False, use_cache=True):
"""Get the response from the openai api call.

Try cache first. If not found, call the openai api. If the api call fails, retry after retry_time.
"""
key = get_key(config)
response = cls._cache.get(key, None)
if response is not None and (response != -1 or not eval_only):
# print("using cached response")
return response
if use_cache:
response = cls._cache.get(key, None)
if response is not None and (response != -1 or not eval_only):
# print("using cached response")
return response
openai_completion = (
openai.ChatCompletion
if config["model"] in cls.chat_models
@@ -215,7 +216,27 @@ class Completion:
)

@classmethod
def eval(cls, config: dict, prune=True, eval_only=False):
def _get_prompt_messages_from_config(cls, model, config):
prompt, messages = None, None
if model in cls.chat_models:
# either "prompt" should be in config (for being compatible with non-chat models)
# or "messages" should be in config (for tuning chat models only)
prompt = config.get("prompt")
messages = config.get("messages")
# either prompt or messages should be in config, but not both
assert (prompt is None) != (
messages is None
), "Either prompt or messages should be in config for chat models."
if prompt is None:
messages = cls._messages[messages]
else:
prompt = cls._prompts[prompt]
else:
prompt = cls._prompts[config["prompt"]]
return prompt, messages

@classmethod
def _eval(cls, config: dict, prune=True, eval_only=False):
"""Evaluate the given config as the hyperparameter setting for the openai api call.

Args:
@@ -242,22 +263,7 @@ class Completion:
max_tokens = config.get(
"max_tokens", np.inf if model in cls.chat_models else 16
)
# default value in OpenAI
if model in cls.chat_models:
# either "prompt" should be in config (for being compatible with non-chat models)
# or "messages" should be in config (for tuning chat models only)
prompt = config.get("prompt")
messages = config.get("messages")
# either prompt or messages should be in config, but not both
assert (prompt is None) != (
messages is None
), "Either prompt or messages should be in config for chat models."
if prompt is None:
messages = cls._messages[messages]
else:
prompt = cls._prompts[prompt]
else:
prompt = cls._prompts[config["prompt"]]
prompt, messages = cls._get_prompt_messages_from_config(model, config)
stop = cls._stops and cls._stops[config["stop"]]
target_output_tokens = None
if not cls.avg_input_tokens:
@@ -309,33 +315,7 @@ class Completion:
f"num_completions={num_completions}, data instance={i}"
)
data_i = data[i]
if prompt is None:
params["messages"] = [
{
"role": m["role"],
"content": m["content"].format(**data_i)
if isinstance(m["content"], str)
else m["content"](data_i),
}
for m in messages
]
elif model in cls.chat_models:
# convert prompt to messages
params["messages"] = [
{
"role": "user",
"content": prompt.format(**data_i)
if isinstance(prompt, str)
else prompt(data_i),
},
]
params.pop("prompt", None)
else:
params["prompt"] = (
prompt.format(**data_i)
if isinstance(prompt, str)
else prompt(data_i)
)
params = cls._construct_params(data_i, params, prompt, messages)
response = cls._get_response(params, eval_only)
if response == -1: # rate limit error, treat as invalid
cls._update_invalid_n(
@@ -481,6 +461,7 @@ class Completion:
"""Tune the parameters for the OpenAI API call.

TODO: support parallel tuning with ray or spark.
TODO: support agg_method as in test

Args:
data (list): The list of data points.
@@ -490,20 +471,20 @@ class Completion:
The function should take a list of responses and a data point as input,
and return a dict of metrics. For example,

```python
def eval_func(responses, **data):
solution = data["solution"]
success_list = []
n = len(responses)
for i in range(n):
response = responses[i]
succeed = is_equiv_chain_of_thought(response, solution)
success_list.append(succeed)
return {
"expected_success": 1 - pow(1 - sum(success_list) / n, n),
"success": any(s for s in success_list),
}
```
```python
def eval_func(responses, **data):
solution = data["solution"]
success_list = []
n = len(responses)
for i in range(n):
response = responses[i]
succeed = is_equiv_chain_of_thought(response, solution)
success_list.append(succeed)
return {
"expected_success": 1 - pow(1 - sum(success_list) / n, n),
"success": any(s for s in success_list),
}
```

log_file_name (str, optional): The log file.
inference_budget (float, optional): The inference budget.
@@ -613,7 +594,7 @@ class Completion:
logger.setLevel(logging_level)
with diskcache.Cache(cls.cache_path) as cls._cache:
analysis = tune.run(
cls.eval,
cls._eval,
search_alg=search_alg,
num_samples=num_samples,
log_file_name=log_file_name,
@@ -650,36 +631,199 @@ class Completion:
"""
if ERROR:
raise ERROR
params = config.copy()
prompt = config.get("prompt")
if "messages" in config:
params["messages"] = [
{
k: v.format(**context) if isinstance(v, str) else v(context)
for k, v in message.items()
}
for message in config["messages"]
]
params.pop("prompt", None)
elif config["model"] in cls.chat_models:
params["messages"] = [
{
"role": "user",
"content": prompt.format(**context)
if isinstance(prompt, str)
else prompt(context),
}
]
params.pop("prompt", None)
else:
params["prompt"] = (
prompt.format(**context) if isinstance(prompt, str) else prompt(context)
)
params = cls._construct_params(context, config)
if use_cache:
with diskcache.Cache(cls.cache_path) as cls._cache:
return cls._get_response(params)
return cls.openai_completion_class.create(**params)

@classmethod
def _construct_params(cls, data_instance, config, prompt=None, messages=None):
params = config.copy()
model = config["model"]
prompt = config.get("prompt") if prompt is None else prompt
messages = config.get("messages") if messages is None else messages
# either "prompt" should be in config (for being compatible with non-chat models)
# or "messages" should be in config (for tuning chat models only)
if prompt is None and model in cls.chat_models:
if messages is None:
raise ValueError(
"Either prompt or messages should be in config for chat models."
)
if prompt is None:
params["messages"] = [
{
"role": m["role"],
"content": m["content"].format(**data_instance)
if isinstance(m["content"], str)
else m["content"](data_instance),
}
for m in messages
]
elif model in cls.chat_models:
# convert prompt to messages
if isinstance(prompt, str):
prompt_msg = prompt.format(**data_instance)
else:
prompt_msg = prompt(data_instance)
params["messages"] = [
{
"role": "user",
"content": prompt_msg
if isinstance(prompt, str)
else prompt(data_instance),
},
]
params.pop("prompt", None)
else:
params["prompt"] = (
prompt.format(**data_instance)
if isinstance(prompt, str)
else prompt(data_instance)
)
return params

@classmethod
def test(
cls,
data,
config,
eval_func=None,
use_cache=True,
agg_method="avg",
return_responses_and_per_instance_result=False,
seed=41,
cache_path=".cache",
):
"""Evaluate the responses created with the config for the OpenAI API call.

Args:
data (list): The list of test data points.
config (dict): Hyperparameter setting for the openai api call.
eval_func (Callable): The evaluation function for responses per data instance.
The function should take a list of responses and a data point as input,
and return a dict of metrics. You need to either provide a valid callable
eval_func; or do not provide one (set None) but call the test function after
calling the tune function in which an eval_func is provided.
In the latter case we will use the eval_func provided via the tune function.
Defaults to None.

```python
def eval_func(responses, **data):
solution = data["solution"]
success_list = []
n = len(responses)
for i in range(n):
response = responses[i]
succeed = is_equiv_chain_of_thought(response, solution)
success_list.append(succeed)
return {
"expected_success": 1 - pow(1 - sum(success_list) / n, n),
"success": any(s for s in success_list),
}
```
use_cache (bool, Optional): Whether to use cached responses. Defaults to True.
agg_method (str, Callable or a dict of Callable): Result aggregation method (across
multiple instances) for each of the metrics. Defaults to 'avg'.
An example agg_method in str:

```python
agg_method = 'median'
```
An example agg_method in a Callable:

```python
agg_method = np.median
```

An example agg_method in a dict of Callable:

```python
agg_method={'median_success': np.median, 'avg_success': np.mean}
```

return_responses_and_per_instance_result (bool): Whether to also return responses
and per instance results in addition to the aggregated results.
seed (int): Random seed for the evaluation. Defaults to 41.
cache_path (str): Path to the cache directory. Defaults to '.cache'.
If a cache directory does not exist, it will be created, otherwise use the existing one.
Returns:
None in case of rate limit error or when a valid eval_func is not provided in either test or tune;
Otherwise, a dict of aggregated results, responses and per instance results if `return_responses_and_per_instance_result` is True;
Otherwise, a dict of aggregated results (responses and per instance results are not returned).
"""
model = config["model"]
result_agg, responses_list, result_list = {}, [], []
metric_keys = None
cls.set_cache(seed, cache_path)
with diskcache.Cache(cls.cache_path) as cls._cache:
for i, data_i in enumerate(data):
logger.info(f"evaluating data instance {i}")
params = cls._construct_params(data_i, config)
response = cls._get_response(
params, eval_only=True, use_cache=use_cache
)
if response == -1: # rate limit error, treat as invalid
return None
# evaluate the quality of the responses
responses = (
[r["message"]["content"].rstrip() for r in response["choices"]]
if model in cls.chat_models
else [r["text"].rstrip() for r in response["choices"]]
)

if eval_func is not None:
metrics = eval_func(responses, **data_i)
elif hasattr(cls, "_eval_func"):
metrics = cls._eval_func(responses, **data_i)
else:
logger.warning(
"Please either provide a valid eval_func or do the test after the tune function is called"
)
return
if not metric_keys:
metric_keys = []
for k in metrics.keys():
try:
_ = float(metrics[k])
metric_keys.append(k)
except ValueError:
pass
result_list.append(metrics)
if return_responses_and_per_instance_result:
responses_list.append(responses)
if isinstance(agg_method, str):
if agg_method in ["avg", "average"]:
for key in metric_keys:
result_agg[key] = np.mean([r[key] for r in result_list])
elif agg_method == "median":
for key in metric_keys:
result_agg[key] = np.median([r[key] for r in result_list])
else:
logger.warning(
f"Aggregation method {agg_method} not supported. Please write your own aggregation method as a callable(s)."
)
elif callable(agg_method):
for key in metric_keys:
result_agg[key] = agg_method([r[key] for r in result_list])
elif isinstance(agg_method, dict):
for key in metric_keys:
metric_agg_method = agg_method[key]
assert callable(
metric_agg_method
), "please provide a callable for each metric"
result_agg[key] = metric_agg_method([r[key] for r in result_list])
else:
raise ValueError(
"agg_method needs to be a string ('avg' or 'median'),\
or a callable, or a dictionary of callable."
)
# should we also return the result_list and responses_list or not?
if return_responses_and_per_instance_result:
return result_agg, result_list, responses_list
else:
return result_agg


class ChatCompletion(Completion):
"""A class for OpenAI API ChatCompletion."""

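Editor's note: a further illustrative sketch (not part of the commit) of the aggregation and per-instance options described in the `test` docstring above. The data, config, and metric function are stand-ins; only the keyword arguments mirror the implementation in this diff.

```python
import numpy as np
from flaml import oai

# Stand-in inputs for illustration only.
data = [{"problem": "Compute 2 + 2.", "solution": "4"}]
config = {"model": "gpt-3.5-turbo", "prompt": "{problem}", "max_tokens": 32, "n": 3}

def success_metrics(responses, solution, **kwargs):
    success = any(solution in r for r in responses)
    return {"success": success, "expected_success": float(success)}

# Per the docstring: aggregate each metric with its own callable, and also
# return the raw responses and the per-instance metric dicts.
agg, per_instance, responses = oai.Completion.test(
    data,
    config,
    success_metrics,
    agg_method={"success": np.mean, "expected_success": np.median},
    return_responses_and_per_instance_result=True,
)
```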
@@ -743,7 +743,7 @@
" # -1 means decided by the optimization budget only\n",
" num_samples=-1,\n",
" # model=\"chatgpt-35-turbo-0301\", # uncomment if using Azure OpenAI\n",
" # model=\"gpt-3-turbo\", # uncomment if you don't have access to gpt-4\n",
" # model=\"gpt-3.5-turbo\", # uncomment if you don't have access to gpt-4\n",
" prompt=prompts, # the prompt templates to choose from\n",
" # stop=\"###\", # the stop sequence\n",
" logging_level=logging.INFO, # the logging level\n",
@@ -1271,8 +1271,9 @@
],
"source": [
"responses = oai.ChatCompletion.create(context=tune_data[1], **config)\n",
"print(responses)\n",
"print(success_metrics([response[\"message\"][\"content\"].rstrip() for response in responses[\"choices\"]], **tune_data[1]))\n"
"metric_results = success_metrics([response[\"message\"][\"content\"].rstrip() for response in responses[\"choices\"]], **tune_data[1])\n",
"print(\"response on an example data instance:\", responses)\n",
"print(\"metric_results on the example data instance:\", metric_results)\n"
]
},
{
@@ -1282,7 +1283,7 @@
"source": [
"### Evaluate the success rate on the test data\n",
"\n",
"You can use flaml's `oai.ChatCompletion.eval` to evaluate the performance of an entire dataset with the tuned config. To do that you need to set `oai.ChatCompletion.data` to the data to evaluate. The following code will take a while (30 mins to 1 hour) to evaluate all the test data instances if uncommented and run. It will cost roughly $3. "
"You can use flaml's `oai.ChatCompletion.test` to evaluate the performance of an entire dataset with the tuned config. The following code will take a while (30 mins to 1 hour) to evaluate all the test data instances if uncommented and run. It will cost roughly $3. "
]
},
{
@@ -1306,9 +1307,8 @@
}
],
"source": [
"# oai.ChatCompletion.data = test_data\n",
"# result = oai.ChatCompletion.eval(analysis.best_config, prune=False, eval_only=True)\n",
"# print(result)"
"# result = oai.Completion.test(test_data, config, success_metrics)\n",
"# print(\"performance on test data with the tuned config:\", result)"
]
},
{
@@ -1336,9 +1336,9 @@
"# assuming you have access to gpt-4; otherwise use gpt-3.5-turbo\n",
"# the following code will cost roughly $2 if uncommented and run.\n",
"\n",
"# default_config = {\"model\": 'gpt-4', \"prompt\": 0}\n",
"# default_result = oai.ChatCompletion.eval(default_config, prune=False, eval_only=True)\n",
"# print(default_result)"
"# default_config = {\"model\": 'gpt-4', \"prompt\": prompts[0]}\n",
"# default_result = oai.Completion.test(test_data, default_config, success_metrics)\n",
"# print(\"performance on test data from gpt-4 with a default config:\", default_result)"
]
},
{
@@ -1382,11 +1382,9 @@
}
],
"source": [
"# The following evaluation costs $8 and nearly one hour if you uncomment it and run it.\n",
"\n",
"# config_larger = {\"model\": 'gpt-4', \"prompt\": 0, \"n\": 5}\n",
"# default_result = oai.ChatCompletion.eval(config_larger, prune=False, eval_only=True)\n",
"# print(default_result)"
"# config_larger = {\"model\": 'gpt-4', \"prompt\": prompts[0], \"n\": 5}\n",
"# default_result = oai.ChatCompletion.test(test_data, config_larger, success_metrics)\n",
"# print(\"performance on test data from gpt-4 with a default config and n=5:\", default_result)"
]
},
{

@@ -613,7 +613,20 @@
],
"source": [
"print(\"optimized config\", config)\n",
"print(\"best result on tuning data\", analysis.best_result)"
"print(\"best result on tuning data\", analysis.best_result)\n",
"\n",
"# save results to notebook_output.txt\n",
"from flaml.version import __version__ as flaml_version\n",
"import datetime\n",
"results = {\"optimized config\": config, \"best result on tuning data\": analysis.best_result,}\n",
"result_info_dict = {\"result_name\": \"integrate_openai.ipynb + optimized config and best result on tuning data\",\n",
" \"flaml_version\": flaml_version, \n",
" \"time\": datetime.datetime.now().strftime(\"%Y-%m-%d %H:%M:%S\"),\n",
" \"results\": results}\n",
"result_info = \"result name: {result_name}, flaml version: {flaml_version}, time: {time}, results: {results}\".format(**result_info_dict)\n",
"with open(\"notebook_output.txt\", \"a\") as f:\n",
" f.write(\"\\n\")\n",
" f.write(result_info)"
]
},
{
@@ -773,8 +786,9 @@
],
"source": [
"responses = oai.Completion.create(context=tune_data[1], **config)\n",
"print(responses)\n",
"print(success_metrics([response[\"message\"][\"content\"] if config[\"model\"] in oai.Completion.chat_models else response[\"text\"] for response in responses[\"choices\"]], **tune_data[1]))\n"
"metric_results = success_metrics([response[\"message\"][\"content\"] if config[\"model\"] in oai.Completion.chat_models else response[\"text\"] for response in responses[\"choices\"]], **tune_data[1])\n",
"print(\"response on an example data instance:\", responses)\n",
"print(\"metric_results on the example data instance:\", metric_results)\n"
]
},
{
@@ -784,7 +798,7 @@
"source": [
"### Evaluate the success rate on the test data\n",
"\n",
"You can use flaml's `oai.Completion.eval` to evaluate the performance of an entire dataset with the tuned config. To do that you need to set `oai.Completion.data` to the data to evaluate. The following code will take a while to evaluate all the 144 test data instances. The cost is about $7 if you uncomment it and run it."
"You can use flaml's `oai.Completion.test` to evaluate the performance of an entire dataset with the tuned config. The following code will take a while to evaluate all the 144 test data instances. The cost is about $7 if you uncomment it and run it."
]
},
{
@@ -808,9 +822,8 @@
}
],
"source": [
"# oai.Completion.data = test_data\n",
"# result = oai.Completion.eval(analysis.best_config, prune=False, eval_only=True)\n",
"# print(result)"
"result = oai.Completion.test(test_data, config, success_metrics)\n",
"print(\"performance on test data with the tuned config:\", result)"
]
},
{
@@ -824,9 +837,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "tutorial",
"language": "python",
"name": "python3"
"name": "tutorial"
},
"language_info": {
"codemirror_mode": {
@@ -838,7 +851,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.7"
},
"vscode": {
"interpreter": {

@@ -133,14 +133,381 @@ def test_humaneval(num_samples=1):
responses = oai.Completion.create(context=test_data[0], **config)
print(responses)
oai.Completion.data = test_data[:num_samples]
result = oai.Completion.eval(analysis.best_config, prune=False, eval_only=True)
result = oai.Completion._eval(analysis.best_config, prune=False, eval_only=True)
print("result without pruning", result)
result = oai.Completion.test(test_data[:num_samples], config=config)
print(result)
except ImportError as exc:
print(exc)


def test_math(num_samples=-1):
from typing import Optional

def remove_boxed(string: str) -> Optional[str]:
"""Source: https://github.com/hendrycks/math
Extract the text within a \\boxed{...} environment.
Example:
>>> remove_boxed(\\boxed{\\frac{2}{3}})
\\frac{2}{3}
"""
left = "\\boxed{"
try:
assert string[: len(left)] == left
assert string[-1] == "}"
return string[len(left) : -1]
except Exception:
return None

def last_boxed_only_string(string: str) -> Optional[str]:
"""Source: https://github.com/hendrycks/math
Extract the last \\boxed{...} or \\fbox{...} element from a string.
"""
idx = string.rfind("\\boxed")
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None

i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1

if right_brace_idx is None:
retval = None
else:
retval = string[idx : right_brace_idx + 1]

return retval

def _fix_fracs(string: str) -> str:
"""Source: https://github.com/hendrycks/math
Reformat fractions.
Examples:
>>> _fix_fracs("\\frac1b")
\frac{1}{b}
>>> _fix_fracs("\\frac12")
\frac{1}{2}
>>> _fix_fracs("\\frac1{72}")
\frac{1}{72}
"""
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except Exception:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string

def _fix_a_slash_b(string: str) -> str:
"""Source: https://github.com/hendrycks/math
Reformat fractions formatted as a/b to \\frac{a}{b}.
Example:
>>> _fix_a_slash_b("2/3")
\frac{2}{3}
"""
if len(string.split("/")) != 2:
return string
a_str = string.split("/")[0]
b_str = string.split("/")[1]
try:
a = int(a_str)
b = int(b_str)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
except Exception:
return string

def _remove_right_units(string: str) -> str:
"""Source: https://github.com/hendrycks/math"""
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string

def _fix_sqrt(string: str) -> str:
"""Source: https://github.com/hendrycks/math"""
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string

def _strip_string(string: str) -> str:
"""Source: https://github.com/hendrycks/math
Apply the reformatting helper functions above.
"""
# linebreaks
string = string.replace("\n", "")
# print(string)

# remove inverse spaces
string = string.replace("\\!", "")
# print(string)

# replace \\ with \
string = string.replace("\\\\", "\\")
# print(string)

# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# print(string)

# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# print(string)

# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")

# remove dollar signs
string = string.replace("\\$", "")

# remove units (on the right)
string = _remove_right_units(string)

# remove percentage
string = string.replace("\\%", "")
string = string.replace(r"\%", "")

# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string

# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]

# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)

# remove spaces
string = string.replace(" ", "")

# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc.
# Even works with \frac1{72} (but not \frac{72}1).
# Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)

# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"

# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)

return string

def get_answer(solution: Optional[str]) -> Optional[str]:
if solution is None:
return None
last_boxed = last_boxed_only_string(solution)
if last_boxed is None:
return None
answer = remove_boxed(last_boxed)
if answer is None:
return None
return answer

def is_equiv(str1: Optional[str], str2: Optional[str]) -> float:
"""Returns (as a float) whether two strings containing math are equivalent up to differences of formatting in
- units
- fractions
- square roots
- superfluous LaTeX.
Source: https://github.com/hendrycks/math
"""
if str1 is None and str2 is None:
print("WARNING: Both None")
return 1.0
if str1 is None or str2 is None:
return 0.0

try:
ss1 = _strip_string(str1)
ss2 = _strip_string(str2)
return float(ss1 == ss2)
except Exception:
return float(str1 == str2)

def is_equiv_chain_of_thought(str1: str, str2: str) -> float:
"""Strips the solution first before calling `is_equiv`."""
ans1 = get_answer(str1)
ans2 = get_answer(str2)

return is_equiv(ans1, ans2)

def success_metrics(responses, solution, **args):
"""Check if each response is correct.

Args:
responses (list): The list of responses.
solution (str): The canonical solution.

Returns:
dict: The success metrics.
"""
success_list = []
n = len(responses)
for i in range(n):
response = responses[i]
succeed = is_equiv_chain_of_thought(response, solution)
success_list.append(succeed)
return {
"expected_success": 1 - pow(1 - sum(success_list) / n, n),
"success": any(s for s in success_list),
}

seed = 41
data = datasets.load_dataset("competition_math")
train_data = data["train"].shuffle(seed=seed)
test_data = data["test"].shuffle(seed=seed)
n_tune_data = 20
tune_data = [
{
"problem": train_data[x]["problem"],
"solution": train_data[x]["solution"],
}
for x in range(len(train_data))
if train_data[x]["level"] == "Level 1"
][:n_tune_data]
test_data = [
{
"problem": test_data[x]["problem"],
"solution": test_data[x]["solution"],
}
for x in range(len(test_data))
if test_data[x]["level"] == "Level 1"
]
print(
"max tokens in tuning data's canonical solutions",
max([len(x["solution"].split()) for x in tune_data]),
)
print(len(tune_data), len(test_data))
# prompt template
prompts = [
lambda data: "Given a mathematics problem, determine the answer. Simplify your answer as much as possible.\n###\nProblem: What is the value of $\\sqrt{3! \\cdot 3!}$ expressed as a positive integer?\nAnswer: $\\sqrt{3!\\cdot3!}$ is equal to $\\sqrt{(3!)^2}=3!=3\\cdot2\\cdot1=\\boxed{6}$.\n###\nProblem: %s\nAnswer:"
+ data["problem"]
]

try:
oai.ChatCompletion.set_cache(seed)
vanilla_config = {
"model": "gpt-3.5-turbo",
"temperature": 1,
"max_tokens": 2048,
"n": 1,
"prompt": prompts[0],
"stop": "###",
}
test_data_sample = test_data[0:3]
result = oai.ChatCompletion.test(
test_data_sample, vanilla_config, success_metrics
)
test_data_sample = test_data[3:6]
result = oai.ChatCompletion.test(
test_data_sample,
vanilla_config,
success_metrics,
use_cache=False,
agg_method="median",
)

def my_median(results):
return np.median(results)

def my_average(results):
return np.mean(results)

result = oai.ChatCompletion.test(
test_data_sample,
vanilla_config,
success_metrics,
use_cache=False,
agg_method=my_median,
)
result = oai.ChatCompletion.test(
test_data_sample,
vanilla_config,
success_metrics,
use_cache=False,
agg_method={"expected_success": my_median, "success": my_average},
)

print(result)

config, _ = oai.ChatCompletion.tune(
data=tune_data, # the data for tuning
metric="expected_success", # the metric to optimize
mode="max", # the optimization mode
eval_func=success_metrics, # the evaluation function to return the success metrics
# log_file_name="logs/math.log", # the log file name
inference_budget=0.002, # the inference budget (dollar)
optimization_budget=0.01, # the optimization budget (dollar)
num_samples=num_samples,
prompt=prompts, # the prompt templates to choose from
stop="###", # the stop sequence
)
print("tuned config", config)
result = oai.ChatCompletion.test(test_data_sample, config)
print("result from tuned config:", result)
except (ImportError, NameError) as exc:
print(exc)


if __name__ == "__main__":
import openai

openai.api_key_path = "test/openai/key.txt"
test_humaneval(-1)
test_math(-1)

@@ -15,13 +15,24 @@ except ImportError:
here = os.path.abspath(os.path.dirname(__file__))


def run_notebook(input_nb, output_nb="executed_notebook.ipynb", save=False):
def run_notebook(input_nb, output_nb="executed_openai_notebook.ipynb", save=False):
try:
file_path = os.path.join(here, os.pardir, os.pardir, "notebook", input_nb)
with open(file_path) as f:
nb = nbformat.read(f, as_version=4)
ep = ExecutePreprocessor(timeout=3600, kernel_name="python3")
ep.preprocess(nb, {"metadata": {"path": here}})

output_file_name = "executed_openai_notebook_output.txt"
output_file = os.path.join(here, output_file_name)
with open(output_file, "a") as f:
for cell in nb.cells:
if cell.cell_type == "code" and "outputs" in cell:
for output in cell.outputs:
if "text" in output:
f.write(output["text"].strip() + "\n")
elif "data" in output and "text/plain" in output["data"]:
f.write(output["data"]["text/plain"].strip() + "\n")
except CellExecutionError:
raise
finally:
@@ -48,3 +59,4 @@ def test_integrate_chatgpt(save=False):

if __name__ == "__main__":
test_integrate_chatgpt(save=True)
test_integrate_openai(save=True)

@@ -173,11 +173,10 @@ print(success_metrics([response["text"].rstrip() for response in responses["choi

#### Evaluate the success rate on the test data

You can use flaml's `oai.Completion.eval` to evaluate the performance of an entire dataset with the tuned config. To do that you need to set `oai.Completion.data` to the data to evaluate.
You can use flaml's `oai.Completion.test` to evaluate the performance of an entire dataset with the tuned config.

```python
oai.Completion.data = test_data
result = oai.Completion.eval(analysis.best_config, prune=False, eval_only=True)
result = oai.Completion.test(test_data, config)
print(result)
```
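Editor's note: one behaviour the updated documentation snippet above relies on, per the `test` docstring earlier in this diff: `tune` stores the `eval_func` it was given, so a later `test` call may omit it. Below is a hedged sketch under the assumption that `tune_data`, `test_data`, and `success_metrics` are the objects built earlier on that documentation page; the metric name, budgets, and prompt template are illustrative values only.

```python
from flaml import oai

# tune() remembers its eval_func; test() falls back to it when none is passed.
config, analysis = oai.Completion.tune(
    data=tune_data,                # tuning data: a list of dicts
    metric="expected_success",     # a metric reported by success_metrics
    mode="max",
    eval_func=success_metrics,
    inference_budget=0.002,        # illustrative budgets, in dollars
    optimization_budget=0.05,
    num_samples=-1,                # -1: let the optimization budget decide
    prompt=["{problem}"],          # illustrative prompt template list
)
result = oai.Completion.test(test_data, config)  # success_metrics is reused
print("performance on test data with the tuned config:", result)
```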