ChatGPT support (#942)

* improve max_valid_n and doc

* Update README.md

Co-authored-by: Li Jiang <lijiang1@microsoft.com>

* add support for chatgpt

* notebook

* newline at end of file

* chatgpt notebook

* ChatGPT in Azure

* doc

* math

* warning, timeout, log file name

* handle import error

* doc update; default value

* paper

* doc

* docstr

* eval_func

* prompt and messages

* remove confusing words

* notebook name

---------

Co-authored-by: Li Jiang <lijiang1@microsoft.com>
Co-authored-by: Susan Xueqing Liu <liususan091219@users.noreply.github.com>
Chi Wang 2023-03-10 11:35:36 -08:00 committed by GitHub
parent 3a606930d1
commit 169012f3e7
12 changed files with 2741 additions and 81 deletions


@ -1,3 +1,3 @@
from flaml.integrations.oai.completion import Completion
from flaml.integrations.oai.completion import Completion, ChatCompletion
__all__ = ["Completion"]
__all__ = ["Completion", "ChatCompletion"]
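Both classes are re-exported from the integration package's `__init__.py`, so downstream code can import them in one line. A minimal sketch; the tests in this commit access the same classes through the `oai` namespace instead:

```python
# Either import style should work after this change; the second mirrors the tests.
from flaml.integrations.oai import Completion, ChatCompletion
from flaml import oai  # then use oai.Completion / oai.ChatCompletion
```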


@ -13,6 +13,7 @@ try:
APIConnectionError,
)
import diskcache
from urllib3.exceptions import ReadTimeoutError
ERROR = None
except ImportError:
@ -39,8 +40,15 @@ def get_key(config):
class Completion:
"""A class for OpenAI API completion."""
"""A class for OpenAI completion API.
It also supports: ChatCompletion, Azure OpenAI API.
"""
# set of models that support chat completion
chat_models = {"gpt-3.5-turbo"}
# price per 1k tokens
price1K = {
"text-ada-001": 0.0004,
"text-babbage-001": 0.0005,
@ -49,6 +57,7 @@ class Completion:
"code-davinci-002": 0.1,
"text-davinci-002": 0.02,
"text-davinci-003": 0.02,
"gpt-3.5-turbo": 0.002,
}
default_search_space = {
@ -70,6 +79,10 @@ class Completion:
# fail a request after hitting RateLimitError for this many seconds
retry_timeout = 60
openai_completion_class = not ERROR and openai.Completion
_total_cost = 0
optimization_budget = None
@classmethod
def set_cache(cls, seed=41, cache_path=".cache"):
"""Set cache path.
@ -95,17 +108,23 @@ class Completion:
# print("using cached response")
return response
retry = 0
openai_completion = (
openai.ChatCompletion
if config["model"] in cls.chat_models
else openai.Completion
)
while eval_only or retry * cls.retry_time < cls.retry_timeout:
try:
response = openai.Completion.create(**config)
response = openai_completion.create(**config)
cls._cache.set(key, response)
return response
except (
ServiceUnavailableError,
APIError,
APIConnectionError,
ReadTimeoutError,
):
logger.info(f"retrying in {cls.retry_time} seconds...", exc_info=1)
logger.warning(f"retrying in {cls.retry_time} seconds...", exc_info=1)
sleep(cls.retry_time)
except RateLimitError:
logger.info(f"retrying in {cls.retry_time} seconds...", exc_info=1)
@ -152,7 +171,11 @@ class Completion:
@classmethod
def _get_region_key(cls, config):
# get a key for the valid/invalid region corresponding to the given config
return (config["model"], config["prompt"], config.get("stop"))
return (
config["model"],
config.get("prompt", config.get("messages")),
config.get("stop"),
)
@classmethod
def _update_invalid_n(cls, prune, region_key, max_tokens, num_completions):
@ -172,25 +195,41 @@ class Completion:
Args:
config (dict): Hyperparameter setting for the openai api call.
prune (bool, optional): Whether to enable pruning. Defaults to True.
eval_only (bool, optional): Whether to evaluate only. Defaults to False.
eval_only (bool, optional): Whether to evaluate only (ignore the inference budget and do not time out).
Defaults to False.
Returns:
dict: Evaluation results.
"""
cost = 0
data = cls.data
model = config["model"]
data_length = len(data)
target_n_tokens = (
1000 * cls.inference_budget / cls.price1K[config["model"]]
if cls.inference_budget and cls.price1K.get(config["model"])
target_n_tokens = getattr(cls, "inference_budget", None) and (
1000 * cls.inference_budget / cls.price1K[model]
if cls.inference_budget and cls.price1K.get(model)
else None
)
prune_hp = cls._prune_hp
prune_hp = getattr(cls, "_prune_hp", "n")
metric = cls._metric
config_n = config[prune_hp]
config_n = config.get(prune_hp, 1) # default value in OpenAI is 1
max_tokens = config.get("max_tokens", 16) # default value in OpenAI is 16
region_key = cls._get_region_key(config)
prompt = cls._prompts[config["prompt"]]
if model in cls.chat_models:
# either "prompt" should be in config (for being compatible with non-chat models)
# or "messages" should be in config (for tuning chat models only)
prompt = config.get("prompt")
messages = config.get("messages")
# either prompt or messages should be in config, but not both
assert (prompt is None) != (
messages is None
), "Either prompt or messages should be in config for chat models."
if prompt is None:
messages = cls._messages[messages]
else:
prompt = cls._prompts[prompt]
else:
prompt = cls._prompts[config["prompt"]]
stop = cls._stops and cls._stops[config["stop"]]
if prune and target_n_tokens:
max_valid_n = cls._get_max_valid_n(region_key, max_tokens)
@ -232,8 +271,37 @@ class Completion:
while True: # data_limit <= data_length
# limit the number of data points to avoid rate limit
for i in range(prev_data_limit, data_limit):
logger.debug(
f"num_completions={num_completions}, data instance={i}"
)
data_i = data[i]
params["prompt"] = prompt.format(**data_i)
if prompt is None:
params["messages"] = [
{
"role": m["role"],
"content": m["content"].format(**data_i)
if isinstance(m["content"], str)
else m["content"](data_i),
}
for m in messages
]
elif model in cls.chat_models:
# convert prompt to messages
params["messages"] = [
{
"role": "user",
"content": prompt.format(**data_i)
if isinstance(prompt, str)
else prompt(data_i),
},
]
params.pop("prompt", None)
else:
params["prompt"] = (
prompt.format(**data_i)
if isinstance(prompt, str)
else prompt(data_i)
)
response = cls._get_response(params, eval_only)
if response == -1: # rate limit error, treat as invalid
cls._update_invalid_n(
@ -243,7 +311,11 @@ class Completion:
result["cost"] = cost
return result
# evaluate the quality of the responses
responses = [r["text"].rstrip() for r in response["choices"]]
responses = (
[r["message"]["content"].rstrip() for r in response["choices"]]
if model in cls.chat_models
else [r["text"].rstrip() for r in response["choices"]]
)
n_tokens = (
response["usage"]["completion_tokens"]
if previous_num_completions
@ -260,9 +332,7 @@ class Completion:
# Under Assumption 1, we should count both the input and output tokens in the first query,
# and only count output tokens afterwards
query_cost = (
response["usage"]["total_tokens"]
* cls.price1K[config["model"]]
/ 1000
response["usage"]["total_tokens"] * cls.price1K[model] / 1000
)
cls._total_cost += query_cost
cost += query_cost
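For a concrete sense of scale, an illustrative calculation with the `gpt-3.5-turbo` entry of `price1K` (the token count is made up for the example):

```python
price_per_1k = 0.002   # Completion.price1K["gpt-3.5-turbo"], dollars per 1K tokens
total_tokens = 1500    # hypothetical usage reported by one response
query_cost = total_tokens * price_per_1k / 1000  # 0.003 dollars for this query
```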
@ -347,9 +417,7 @@ class Completion:
result[key] /= data_limit
result["total_cost"] = cls._total_cost
result["cost"] = cost
result["inference_cost"] = (
avg_n_tokens * cls.price1K[config["model"]] / 1000
)
result["inference_cost"] = avg_n_tokens * cls.price1K[model] / 1000
if prune and target_n_tokens and not cls.avg_input_tokens:
cls.avg_input_tokens = np.mean(input_tokens)
break
@ -374,6 +442,7 @@ class Completion:
inference_budget=None,
optimization_budget=None,
num_samples=1,
logging_level=logging.WARNING,
**config,
):
"""Tune the parameters for the OpenAI API call.
@ -385,13 +454,42 @@ class Completion:
metric (str): The metric to optimize.
mode (str): The optimization mode, "min" or "max".
eval_func (Callable): The evaluation function for responses.
The function should take a list of responses and a data point as input,
and return a dict of metrics. For example,
```python
def eval_func(responses, **data):
solution = data["solution"]
success_list = []
n = len(responses)
for i in range(n):
response = responses[i]
succeed = is_equiv_chain_of_thought(response, solution)
success_list.append(succeed)
return {
"expected_success": 1 - pow(1 - sum(success_list) / n, n),
"success": any(s for s in success_list),
}
```
log_file_name (str, optional): The log file name.
inference_budget (float, optional): The inference budget.
optimization_budget (float, optional): The optimization budget.
num_samples (int, optional): The number of samples to evaluate.
-1 means no hard restriction on the number of trials
and the actual number is decided by optimization_budget. Defaults to 1.
**config (dict): The search space to update over the default search.
For prompt, please provide a string or a list of strings.
For prompt, please provide a string/Callable or a list of strings/Callables.
- If prompt is provided for chat models, it will be converted to messages under role "user".
- Do not provide both prompt and messages for chat models, but provide either of them.
- A string `prompt` template will be used to generate a prompt for each data instance
using `prompt.format(**data)`.
- A callable `prompt` template will be used to generate a prompt for each data instance
using `prompt(data)`.
For stop, please provide a string, a list of strings, or a list of lists of strings.
For messages (chat models only), please provide a list of messages (for a single chat prefix)
or a list of lists of messages (for multiple chat prefixes to choose from).
Each message should be a dict with keys "role" and "content".
Returns:
dict: The optimized hyperparameter setting.
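A sketch of the chat-model path described above, using the `tune_data` and `success_metrics` defined in the test file later in this commit; the system message is a made-up addition included only to show a multi-message prefix. Per the formatting loop above, a message's `content` may also be a callable that receives the data instance.

```python
from flaml import oai

config, analysis = oai.ChatCompletion.tune(
    data=tune_data,            # list of dicts; each dict is one data instance
    metric="success",
    mode="max",
    eval_func=success_metrics,
    # One chat prefix: string contents are filled with data fields via .format.
    messages=[
        {"role": "system", "content": "You are a helpful coding assistant."},
        {"role": "user", "content": "{prompt}"},
    ],
    n=1,
)
```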
@ -399,9 +497,11 @@ class Completion:
"""
if ERROR:
raise ERROR
space = Completion.default_search_space.copy()
space = cls.default_search_space.copy()
if config is not None:
space.update(config)
if "messages" in space:
space.pop("prompt", None)
temperature = space.pop("temperature", None)
top_p = space.pop("top_p", None)
if temperature is not None and top_p is None:
@ -415,58 +515,69 @@ class Completion:
logger.warning(
"temperature and top_p are not recommended to vary together."
)
with diskcache.Cache(cls.cache_path) as cls._cache:
cls._max_valid_n_per_max_tokens, cls._min_invalid_n_per_max_tokens = {}, {}
cls.optimization_budget = optimization_budget
cls.inference_budget = inference_budget
cls._prune_hp = "best_of" if space.get("best_of", 1) != 1 else "n"
cls._prompts = space["prompt"]
cls._max_valid_n_per_max_tokens, cls._min_invalid_n_per_max_tokens = {}, {}
cls.optimization_budget = optimization_budget
cls.inference_budget = inference_budget
cls._prune_hp = "best_of" if space.get("best_of", 1) != 1 else "n"
cls._prompts = space.get("prompt")
if cls._prompts is None:
cls._messages = space.get("messages")
assert isinstance(cls._messages, list) and isinstance(
cls._messages[0], (dict, list)
), "messages must be a list of dicts or a list of lists."
if isinstance(cls._messages[0], dict):
cls._messages = [cls._messages]
space["messages"] = tune.choice(list(range(len(cls._messages))))
else:
assert (
space.get("messages") is None
), "messages and prompt cannot be provided at the same time."
assert isinstance(
cls._prompts, (str, list)
), "prompt must be a string or a list of strings."
if isinstance(cls._prompts, str):
cls._prompts = [cls._prompts]
space["prompt"] = tune.choice(list(range(len(cls._prompts))))
cls._stops = space.get("stop")
if cls._stops:
assert isinstance(
cls._stops, (str, list)
), "stop must be a string, a list of strings, or a list of lists of strings."
if not (
isinstance(cls._stops, list) and isinstance(cls._stops[0], list)
):
cls._stops = [cls._stops]
space["stop"] = tune.choice(list(range(len(cls._stops))))
cls._metric, cls._mode = metric, mode
cls._total_cost = 0 # total optimization cost
cls._eval_func = eval_func
cls.data = data
cls.avg_input_tokens = None
cls._stops = space.get("stop")
if cls._stops:
assert isinstance(
cls._stops, (str, list)
), "stop must be a string, a list of strings, or a list of lists of strings."
if not (isinstance(cls._stops, list) and isinstance(cls._stops[0], list)):
cls._stops = [cls._stops]
space["stop"] = tune.choice(list(range(len(cls._stops))))
cls._metric, cls._mode = metric, mode
cls._total_cost = 0 # total optimization cost
cls._eval_func = eval_func
cls.data = data
cls.avg_input_tokens = None
search_alg = BlendSearch(
cost_attr="cost",
cost_budget=optimization_budget,
metric=metric,
mode=mode,
space=space,
)
if len(space["model"]) > 1:
# start all the models with the same hp config
config0 = search_alg.suggest("t0")
points_to_evaluate = [config0]
for model in space["model"]:
if model != config0["model"]:
point = config0.copy()
point["model"] = model
points_to_evaluate.append(point)
search_alg = BlendSearch(
cost_attr="cost",
cost_budget=optimization_budget,
metric=metric,
mode=mode,
space=space,
points_to_evaluate=points_to_evaluate,
)
if len(space["model"]) > 1:
# start all the models with the same hp config
config0 = search_alg.suggest("t0")
points_to_evaluate = [config0]
for model in space["model"]:
if model != config0["model"]:
point = config0.copy()
point["model"] = model
points_to_evaluate.append(point)
search_alg = BlendSearch(
cost_attr="cost",
cost_budget=optimization_budget,
metric=metric,
mode=mode,
space=space,
points_to_evaluate=points_to_evaluate,
)
logger.setLevel(logging_level)
with diskcache.Cache(cls.cache_path) as cls._cache:
analysis = tune.run(
cls.eval,
search_alg=search_alg,
@ -474,14 +585,17 @@ class Completion:
log_file_name=log_file_name,
verbose=3,
)
config = analysis.best_config
params = config.copy()
config = analysis.best_config
params = config.copy()
if cls._prompts:
params["prompt"] = cls._prompts[config["prompt"]]
stop = cls._stops and cls._stops[config["stop"]]
params["stop"] = stop
temperature_or_top_p = params.pop("temperature_or_top_p", None)
if temperature_or_top_p:
params.update(temperature_or_top_p)
else:
params["messages"] = cls._messages[config["messages"]]
stop = cls._stops and cls._stops[config["stop"]]
params["stop"] = stop
temperature_or_top_p = params.pop("temperature_or_top_p", None)
if temperature_or_top_p:
params.update(temperature_or_top_p)
return params, analysis
@classmethod
@ -503,7 +617,43 @@ class Completion:
if ERROR:
raise ERROR
params = config.copy()
params["prompt"] = config["prompt"].format(**context)
prompt = config.get("prompt")
if "messages" in config:
params["messages"] = [
{
k: v.format(**context) if isinstance(v, str) else v(context)
for k, v in message.items()
}
for message in config["messages"]
]
params.pop("prompt", None)
elif config["model"] in cls.chat_models:
params["messages"] = [
{
"role": "user",
"content": prompt.format(**context)
if isinstance(prompt, str)
else prompt(context),
}
]
params.pop("prompt", None)
else:
params["prompt"] = (
prompt.format(**context) if isinstance(prompt, str) else prompt(context)
)
if use_cache:
return cls._get_response(params)
return openai.Completion.create(**params)
with diskcache.Cache(cls.cache_path) as cls._cache:
return cls._get_response(params)
return cls.openai_completion_class.create(**params)
class ChatCompletion(Completion):
"""A class for OpenAI API ChatCompletion."""
price1K = {
"gpt-3.5-turbo": 0.002,
}
default_search_space = Completion.default_search_space.copy()
default_search_space["model"] = tune.choice(list(price1K.keys()))
openai_completion_class = not ERROR and openai.ChatCompletion
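`ChatCompletion` narrows the default model search space to chat models and reuses everything else from `Completion`. Inference with a tuned config then looks like the test added in this commit; `test_data` is a placeholder for held-out instances:

```python
# `config` is the best configuration returned by ChatCompletion.tune;
# `context` fills the prompt/messages templates for one data instance.
response = oai.ChatCompletion.create(context=test_data[0], use_cache=True, **config)
print(response["choices"][0]["message"]["content"])
```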


@ -1 +1 @@
__version__ = "1.1.3"
__version__ = "1.2.0"

File diff suppressed because it is too large

File diff suppressed because it is too large


@ -458,6 +458,7 @@
" \"code-davinci-002\": 0.1,\n",
" \"text-davinci-002\": 0.02,\n",
" \"text-davinci-003\": 0.02,\n",
" \"gpt-3.5-turbo\": 0.002,\n",
"}\n",
"\n",
"default_search_space = {\n",


@ -804,5 +804,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
}


@ -120,7 +120,7 @@ setuptools.setup(
"pytorch-forecasting>=0.9.0",
],
"benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],
"openai": ["openai==0.23.1", "diskcache", "optuna==2.8.0"],
"openai": ["openai==0.27.0", "diskcache", "optuna==2.8.0"],
"synapse": ["joblibspark>=0.5.0", "optuna==2.8.0", "pyspark>=3.0.0"],
},
classifiers=[


@ -80,13 +80,36 @@ def test_humaneval(num_samples=1):
oai.Completion.set_cache(seed)
try:
# a minimal tuning example
oai.Completion.tune(
config, _ = oai.Completion.tune(
data=tune_data,
metric="success",
mode="max",
eval_func=success_metrics,
n=1,
)
responses = oai.Completion.create(context=test_data[0], **config)
# a minimal tuning example for tuning chat completion models using the Completion class
config, _ = oai.Completion.tune(
data=tune_data,
metric="success",
mode="max",
eval_func=success_metrics,
n=1,
model="gpt-3.5-turbo",
)
responses = oai.Completion.create(context=test_data[0], **config)
# a minimal tuning example for tuning chat completion models using the ChatCompletion class
config, _ = oai.ChatCompletion.tune(
data=tune_data,
metric="success",
mode="max",
eval_func=success_metrics,
n=1,
messages=[{"role": "user", "content": "{prompt}"}],
)
responses = oai.ChatCompletion.create(context=test_data[0], **config)
print(responses)
return
# a more comprehensive tuning example
config, analysis = oai.Completion.tune(
data=tune_data,
@ -94,8 +117,8 @@ def test_humaneval(num_samples=1):
mode="max",
eval_func=success_metrics,
log_file_name="logs/humaneval.log",
inference_budget=0.02,
optimization_budget=5,
inference_budget=0.002,
optimization_budget=2,
num_samples=num_samples,
prompt=[
"{prompt}",


@ -38,5 +38,13 @@ def test_integrate_openai(save=False):
run_notebook("integrate_openai.ipynb", save=save)
@pytest.mark.skipif(
skip,
reason="do not run openai test if openai is not installed",
)
def test_integrate_chatgpt(save=False):
run_notebook("integrate_chatgpt_math.ipynb", save=save)
if __name__ == "__main__":
test_integrate_openai(save=True)
test_integrate_chatgpt(save=True)


@ -1,10 +1,10 @@
FLAML has integrated the OpenAI's completion API. In this example, we will tune several hyperparameters including the temperature, prompt and n to optimize the inference performance of OpenAI's completion API for a code generation task. Our study shows that tuning hyperparameters can significantly affect the utility of the OpenAI API.
FLAML offers a cost-effective hyperparameter optimization technique [EcoOptiGen](https://arxiv.org/abs/2303.04673). In this example, we will tune several hyperparameters for OpenAI's completion API, including the temperature, prompt, and n (number of completions), to optimize the inference performance for a code generation task. Our study shows that tuning hyperparameters can significantly affect the utility of the OpenAI API.
### Prerequisites
Install the [openai] option. The option is available in flaml since version 1.1.3. This feature is subject to change in future versions.
Install the [openai] option. The OpenAI integration is in preview. ChatGPT support is available since version 1.2.0.
```bash
pip install "flaml[openai]==1.1.3"
pip install "flaml[openai]==1.2.0"
```
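A minimal sketch of the tuning call, modeled on the test added in this commit; `tune_data`, `test_data`, and `success_metrics` stand in for the user's own data splits and metric function, and the second prompt template is just an illustrative alternative:

```python
from flaml import oai

config, analysis = oai.Completion.tune(
    data=tune_data,
    metric="success",
    mode="max",
    eval_func=success_metrics,
    log_file_name="logs/humaneval.log",
    inference_budget=0.002,   # dollars per instance at inference time
    optimization_budget=2,    # total dollars to spend on tuning
    num_samples=-1,           # let the budget decide the number of trials
    prompt=["{prompt}", "# Complete the Python function:\n{prompt}"],
)
responses = oai.Completion.create(context=test_data[0], **config)
```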


@ -91,3 +91,14 @@ year={2023},
url={https://openreview.net/forum?id=0Ij9_q567Ma}
}
```
* [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).
```bibtex
@inproceedings{wang2023EcoOptiGen,
title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},
author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},
year={2023},
booktitle={ArXiv preprint arXiv:2303.04673},
}
```