mirror of https://github.com/microsoft/autogen.git
ChatGPT support (#942)
* improve max_valid_n and doc
* Update README.md (Co-authored-by: Li Jiang <lijiang1@microsoft.com>)
* add support for chatgpt
* notebook
* newline at end of file
* chatgpt notebook
* ChatGPT in Azure
* doc
* math
* warning, timeout, log file name
* handle import error
* doc update; default value
* paper
* doc
* docstr
* eval_func
* prompt and messages
* remove confusing words
* notebook name

Co-authored-by: Li Jiang <lijiang1@microsoft.com>
Co-authored-by: Susan Xueqing Liu <liususan091219@users.noreply.github.com>
This commit is contained in: parent 3a606930d1, commit 169012f3e7
flaml/integrations/oai/__init__.py

@@ -1,3 +1,3 @@
-from flaml.integrations.oai.completion import Completion
+from flaml.integrations.oai.completion import Completion, ChatCompletion

-__all__ = ["Completion"]
+__all__ = ["Completion", "ChatCompletion"]
flaml/integrations/oai/completion.py

@@ -13,6 +13,7 @@ try:
         APIConnectionError,
     )
     import diskcache
+    from urllib3.exceptions import ReadTimeoutError

     ERROR = None
 except ImportError:
@@ -39,8 +40,15 @@ def get_key(config):


 class Completion:
-    """A class for OpenAI API completion."""
+    """A class for OpenAI completion API.
+
+    It also supports: ChatCompletion, Azure OpenAI API.
+    """
+
+    # set of models that support chat completion
+    chat_models = {"gpt-3.5-turbo"}

     # price per 1k tokens
     price1K = {
         "text-ada-001": 0.0004,
         "text-babbage-001": 0.0005,
@@ -49,6 +57,7 @@ class Completion:
         "code-davinci-002": 0.1,
         "text-davinci-002": 0.02,
         "text-davinci-003": 0.02,
+        "gpt-3.5-turbo": 0.002,
     }

     default_search_space = {
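The `price1K` table drives the cost accounting later in the diff: a query costs `total_tokens * price1K[model] / 1000` dollars, so a gpt-3.5-turbo call that used 1,500 tokens costs about $0.003. A minimal sketch of that formula (the helper name is illustrative, not flaml API):

```python
# price per 1K tokens, copied from the diff above
price1K = {"gpt-3.5-turbo": 0.002, "text-davinci-003": 0.02}

def query_cost(model: str, total_tokens: int) -> float:
    """Dollar cost of one request: tokens / 1000 * price per 1K tokens."""
    return total_tokens * price1K[model] / 1000

print(query_cost("gpt-3.5-turbo", 1500))  # ~0.003 dollars
```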
@@ -70,6 +79,10 @@ class Completion:
     # fail a request after hitting RateLimitError for this many seconds
     retry_timeout = 60

+    openai_completion_class = not ERROR and openai.Completion
+    _total_cost = 0
+    optimization_budget = None
+
     @classmethod
     def set_cache(cls, seed=41, cache_path=".cache"):
         """Set cache path.
@@ -95,17 +108,23 @@ class Completion:
             # print("using cached response")
             return response
         retry = 0
+        openai_completion = (
+            openai.ChatCompletion
+            if config["model"] in cls.chat_models
+            else openai.Completion
+        )
         while eval_only or retry * cls.retry_time < cls.retry_timeout:
             try:
-                response = openai.Completion.create(**config)
+                response = openai_completion.create(**config)
                 cls._cache.set(key, response)
                 return response
             except (
                 ServiceUnavailableError,
                 APIError,
                 APIConnectionError,
+                ReadTimeoutError,
             ):
-                logger.info(f"retrying in {cls.retry_time} seconds...", exc_info=1)
+                logger.warning(f"retrying in {cls.retry_time} seconds...", exc_info=1)
                 sleep(cls.retry_time)
             except RateLimitError:
                 logger.info(f"retrying in {cls.retry_time} seconds...", exc_info=1)
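The loop above retries transient failures until `retry_timeout` seconds of retrying are exhausted, now dispatching to the chat endpoint for chat models. A standalone sketch of the same retry pattern (the constants, `call_api`, and the exception type are illustrative stand-ins, not part of flaml):

```python
import logging
from time import sleep

logger = logging.getLogger(__name__)
RETRY_TIME = 10     # seconds to wait between attempts
RETRY_TIMEOUT = 60  # stop retrying after this many seconds

def call_with_retry(call_api, **config):
    """Retry a flaky API call until the retry budget is exhausted."""
    retry = 0
    while retry * RETRY_TIME < RETRY_TIMEOUT:
        try:
            return call_api(**config)
        except ConnectionError:  # stand-in for the OpenAI/urllib3 errors caught above
            logger.warning("retrying in %d seconds...", RETRY_TIME)
            sleep(RETRY_TIME)
            retry += 1
    raise TimeoutError("request failed after exhausting the retry budget")
```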
@@ -152,7 +171,11 @@ class Completion:
     @classmethod
     def _get_region_key(cls, config):
         # get a key for the valid/invalid region corresponding to the given config
-        return (config["model"], config["prompt"], config.get("stop"))
+        return (
+            config["model"],
+            config.get("prompt", config.get("messages")),
+            config.get("stop"),
+        )

     @classmethod
     def _update_invalid_n(cls, prune, region_key, max_tokens, num_completions):
@@ -172,25 +195,41 @@ class Completion:
         Args:
             config (dict): Hyperparameter setting for the openai api call.
             prune (bool, optional): Whether to enable pruning. Defaults to True.
-            eval_only (bool, optional): Whether to evaluate only. Defaults to False.
+            eval_only (bool, optional): Whether to evaluate only (ignore the inference budget and no timeout).
+                Defaults to False.

         Returns:
             dict: Evaluation results.
         """
         cost = 0
         data = cls.data
+        model = config["model"]
         data_length = len(data)
-        target_n_tokens = (
-            1000 * cls.inference_budget / cls.price1K[config["model"]]
-            if cls.inference_budget and cls.price1K.get(config["model"])
+        target_n_tokens = getattr(cls, "inference_budget", None) and (
+            1000 * cls.inference_budget / cls.price1K[model]
+            if cls.inference_budget and cls.price1K.get(model)
             else None
         )
-        prune_hp = cls._prune_hp
+        prune_hp = getattr(cls, "_prune_hp", "n")
         metric = cls._metric
-        config_n = config[prune_hp]
+        config_n = config.get(prune_hp, 1)  # default value in OpenAI is 1
         max_tokens = config.get("max_tokens", 16)  # default value in OpenAI is 16
         region_key = cls._get_region_key(config)
-        prompt = cls._prompts[config["prompt"]]
+        if model in cls.chat_models:
+            # either "prompt" should be in config (for being compatible with non-chat models)
+            # or "messages" should be in config (for tuning chat models only)
+            prompt = config.get("prompt")
+            messages = config.get("messages")
+            # either prompt or messages should be in config, but not both
+            assert (prompt is None) != (
+                messages is None
+            ), "Either prompt or messages should be in config for chat models."
+            if prompt is None:
+                messages = cls._messages[messages]
+            else:
+                prompt = cls._prompts[prompt]
+        else:
+            prompt = cls._prompts[config["prompt"]]
         stop = cls._stops and cls._stops[config["stop"]]
         if prune and target_n_tokens:
             max_valid_n = cls._get_max_valid_n(region_key, max_tokens)
@@ -232,8 +271,37 @@ class Completion:
             while True:  # data_limit <= data_length
                 # limit the number of data points to avoid rate limit
                 for i in range(prev_data_limit, data_limit):
                     logger.debug(
                         f"num_completions={num_completions}, data instance={i}"
                     )
                     data_i = data[i]
-                    params["prompt"] = prompt.format(**data_i)
+                    if prompt is None:
+                        params["messages"] = [
+                            {
+                                "role": m["role"],
+                                "content": m["content"].format(**data_i)
+                                if isinstance(m["content"], str)
+                                else m["content"](data_i),
+                            }
+                            for m in messages
+                        ]
+                    elif model in cls.chat_models:
+                        # convert prompt to messages
+                        params["messages"] = [
+                            {
+                                "role": "user",
+                                "content": prompt.format(**data_i)
+                                if isinstance(prompt, str)
+                                else prompt(data_i),
+                            },
+                        ]
+                        params.pop("prompt", None)
+                    else:
+                        params["prompt"] = (
+                            prompt.format(**data_i)
+                            if isinstance(prompt, str)
+                            else prompt(data_i)
+                        )
                     response = cls._get_response(params, eval_only)
                     if response == -1:  # rate limit error, treat as invalid
                         cls._update_invalid_n(
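For chat models, each data instance is rendered either from an explicit `messages` template or from a `prompt` template converted to a single user message, never both. A standalone sketch of that conversion step (the helper name is illustrative, not flaml API):

```python
def build_messages(data_i: dict, prompt=None, messages=None) -> list:
    """Render chat messages for one data instance from a prompt or a message template."""
    assert (prompt is None) != (messages is None), "provide prompt or messages, not both"
    if messages is not None:
        return [
            {
                "role": m["role"],
                "content": m["content"].format(**data_i)
                if isinstance(m["content"], str)
                else m["content"](data_i),
            }
            for m in messages
        ]
    content = prompt.format(**data_i) if isinstance(prompt, str) else prompt(data_i)
    return [{"role": "user", "content": content}]

print(build_messages({"problem": "1+1=?"}, prompt="Solve: {problem}"))
# [{'role': 'user', 'content': 'Solve: 1+1=?'}]
```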
@@ -243,7 +311,11 @@ class Completion:
                 result["cost"] = cost
                 return result
             # evaluate the quality of the responses
-            responses = [r["text"].rstrip() for r in response["choices"]]
+            responses = (
+                [r["message"]["content"].rstrip() for r in response["choices"]]
+                if model in cls.chat_models
+                else [r["text"].rstrip() for r in response["choices"]]
+            )
             n_tokens = (
                 response["usage"]["completion_tokens"]
                 if previous_num_completions
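Chat responses put the generated text under `message.content` rather than `text`. A minimal helper capturing the same difference (a sketch, not flaml API):

```python
def extract_texts(response: dict, is_chat: bool) -> list:
    """Pull the completion strings out of an OpenAI-style response dict."""
    if is_chat:
        return [c["message"]["content"].rstrip() for c in response["choices"]]
    return [c["text"].rstrip() for c in response["choices"]]
```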
@@ -260,9 +332,7 @@ class Completion:
                     # Under Assumption 1, we should count both the input and output tokens in the first query,
                     # and only count output tokens afterwards
                     query_cost = (
-                        response["usage"]["total_tokens"]
-                        * cls.price1K[config["model"]]
-                        / 1000
+                        response["usage"]["total_tokens"] * cls.price1K[model] / 1000
                     )
                     cls._total_cost += query_cost
                     cost += query_cost
@@ -347,9 +417,7 @@ class Completion:
                 result[key] /= data_limit
             result["total_cost"] = cls._total_cost
             result["cost"] = cost
-            result["inference_cost"] = (
-                avg_n_tokens * cls.price1K[config["model"]] / 1000
-            )
+            result["inference_cost"] = avg_n_tokens * cls.price1K[model] / 1000
             if prune and target_n_tokens and not cls.avg_input_tokens:
                 cls.avg_input_tokens = np.mean(input_tokens)
             break
@@ -374,6 +442,7 @@ class Completion:
         inference_budget=None,
         optimization_budget=None,
         num_samples=1,
+        logging_level=logging.WARNING,
         **config,
     ):
         """Tune the parameters for the OpenAI API call.
@@ -385,13 +454,42 @@ class Completion:
             metric (str): The metric to optimize.
             mode (str): The optimization mode, "min" or "max".
             eval_func (Callable): The evaluation function for responses.
+                The function should take a list of responses and a data point as input,
+                and return a dict of metrics. For example,
+
+        ```python
+        def eval_func(responses, **data):
+            solution = data["solution"]
+            success_list = []
+            n = len(responses)
+            for i in range(n):
+                response = responses[i]
+                succeed = is_equiv_chain_of_thought(response, solution)
+                success_list.append(succeed)
+            return {
+                "expected_success": 1 - pow(1 - sum(success_list) / n, n),
+                "success": any(s for s in success_list),
+            }
+        ```
+
             log_file_name (str, optional): The log file.
             inference_budget (float, optional): The inference budget.
             optimization_budget (float, optional): The optimization budget.
             num_samples (int, optional): The number of samples to evaluate.
                 -1 means no hard restriction in the number of trials
                 and the actual number is decided by optimization_budget. Defaults to 1.
             **config (dict): The search space to update over the default search.
-                For prompt, please provide a string or a list of strings.
+                For prompt, please provide a string/Callable or a list of strings/Callables.
+                    - If prompt is provided for chat models, it will be converted to messages under role "user".
+                    - Do not provide both prompt and messages for chat models, but provide either of them.
+                    - A string `prompt` template will be used to generate a prompt for each data instance
+                      using `prompt.format(**data)`.
+                    - A callable `prompt` template will be used to generate a prompt for each data instance
+                      using `prompt(data)`.
                 For stop, please provide a string, a list of strings, or a list of lists of strings.
+                For messages (chat models only), please provide a list of messages (for a single chat prefix)
+                    or a list of lists of messages (for multiple choices of chat prefix to choose from).
+                    Each message should be a dict with keys "role" and "content".

         Returns:
             dict: The optimized hyperparameter setting.
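Putting the docstring together, a minimal tuning call for a chat model mirrors the test added later in this commit. The toy data and metric below are illustrative placeholders (a real run needs a meaningful `eval_func` and a valid OpenAI API key):

```python
from flaml import oai

# toy data: each dict fills the "{prompt}" template below
tune_data = [{"prompt": "1+1=", "solution": "2"}]

def success_metrics(responses, **data):
    # eval_func contract from the docstring: responses + data point -> dict of metrics
    return {"success": any(data["solution"] in r for r in responses)}

config, analysis = oai.ChatCompletion.tune(
    data=tune_data,
    metric="success",
    mode="max",
    eval_func=success_metrics,
    n=1,
    messages=[{"role": "user", "content": "{prompt}"}],
)
```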
@@ -399,9 +497,11 @@ class Completion:
         """
         if ERROR:
             raise ERROR
-        space = Completion.default_search_space.copy()
+        space = cls.default_search_space.copy()
         if config is not None:
             space.update(config)
+            if "messages" in space:
+                space.pop("prompt", None)
         temperature = space.pop("temperature", None)
         top_p = space.pop("top_p", None)
         if temperature is not None and top_p is None:
@@ -415,58 +515,69 @@ class Completion:
             logger.warning(
                 "temperature and top_p are not recommended to vary together."
             )
-        with diskcache.Cache(cls.cache_path) as cls._cache:
-            cls._max_valid_n_per_max_tokens, cls._min_invalid_n_per_max_tokens = {}, {}
-            cls.optimization_budget = optimization_budget
-            cls.inference_budget = inference_budget
-            cls._prune_hp = "best_of" if space.get("best_of", 1) != 1 else "n"
-            cls._prompts = space["prompt"]
-            if isinstance(cls._prompts, str):
-                cls._prompts = [cls._prompts]
-            space["prompt"] = tune.choice(list(range(len(cls._prompts))))
-            cls._stops = space.get("stop")
-            if cls._stops:
-                assert isinstance(
-                    cls._stops, (str, list)
-                ), "stop must be a string, a list of strings, or a list of lists of strings."
-                if not (
-                    isinstance(cls._stops, list) and isinstance(cls._stops[0], list)
-                ):
-                    cls._stops = [cls._stops]
-                space["stop"] = tune.choice(list(range(len(cls._stops))))
-            cls._metric, cls._mode = metric, mode
-            cls._total_cost = 0  # total optimization cost
-            cls._eval_func = eval_func
-            cls.data = data
-            cls.avg_input_tokens = None
-
-            search_alg = BlendSearch(
-                cost_attr="cost",
-                cost_budget=optimization_budget,
-                metric=metric,
-                mode=mode,
-                space=space,
-            )
-            if len(space["model"]) > 1:
-                # start all the models with the same hp config
-                config0 = search_alg.suggest("t0")
-                points_to_evaluate = [config0]
-                for model in space["model"]:
-                    if model != config0["model"]:
-                        point = config0.copy()
-                        point["model"] = model
-                        points_to_evaluate.append(point)
-                search_alg = BlendSearch(
-                    cost_attr="cost",
-                    cost_budget=optimization_budget,
-                    metric=metric,
-                    mode=mode,
-                    space=space,
-                    points_to_evaluate=points_to_evaluate,
-                )
+        cls._max_valid_n_per_max_tokens, cls._min_invalid_n_per_max_tokens = {}, {}
+        cls.optimization_budget = optimization_budget
+        cls.inference_budget = inference_budget
+        cls._prune_hp = "best_of" if space.get("best_of", 1) != 1 else "n"
+        cls._prompts = space.get("prompt")
+        if cls._prompts is None:
+            cls._messages = space.get("messages")
+            assert isinstance(cls._messages, list) and isinstance(
+                cls._messages[0], (dict, list)
+            ), "messages must be a list of dicts or a list of lists."
+            if isinstance(cls._messages[0], dict):
+                cls._messages = [cls._messages]
+            space["messages"] = tune.choice(list(range(len(cls._messages))))
+        else:
+            assert (
+                space.get("messages") is None
+            ), "messages and prompt cannot be provided at the same time."
+            assert isinstance(
+                cls._prompts, (str, list)
+            ), "prompt must be a string or a list of strings."
+            if isinstance(cls._prompts, str):
+                cls._prompts = [cls._prompts]
+            space["prompt"] = tune.choice(list(range(len(cls._prompts))))
+        cls._stops = space.get("stop")
+        if cls._stops:
+            assert isinstance(
+                cls._stops, (str, list)
+            ), "stop must be a string, a list of strings, or a list of lists of strings."
+            if not (isinstance(cls._stops, list) and isinstance(cls._stops[0], list)):
+                cls._stops = [cls._stops]
+            space["stop"] = tune.choice(list(range(len(cls._stops))))
+        cls._metric, cls._mode = metric, mode
+        cls._total_cost = 0  # total optimization cost
+        cls._eval_func = eval_func
+        cls.data = data
+        cls.avg_input_tokens = None
+
+        search_alg = BlendSearch(
+            cost_attr="cost",
+            cost_budget=optimization_budget,
+            metric=metric,
+            mode=mode,
+            space=space,
+        )
+        if len(space["model"]) > 1:
+            # start all the models with the same hp config
+            config0 = search_alg.suggest("t0")
+            points_to_evaluate = [config0]
+            for model in space["model"]:
+                if model != config0["model"]:
+                    point = config0.copy()
+                    point["model"] = model
+                    points_to_evaluate.append(point)
+            search_alg = BlendSearch(
+                cost_attr="cost",
+                cost_budget=optimization_budget,
+                metric=metric,
+                mode=mode,
+                space=space,
+                points_to_evaluate=points_to_evaluate,
+            )
+        logger.setLevel(logging_level)
+        with diskcache.Cache(cls.cache_path) as cls._cache:
             analysis = tune.run(
                 cls.eval,
                 search_alg=search_alg,
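When the search space contains several models, the code above seeds the search so that every model is first tried with the same hyperparameter configuration. A toy illustration of building such seed points (the values are illustrative, not from the commit):

```python
# config0 is the first suggestion from the search algorithm
config0 = {"model": "text-davinci-003", "n": 1, "max_tokens": 16}
models = ["text-davinci-003", "gpt-3.5-turbo"]

points_to_evaluate = [config0]
for model in models:
    if model != config0["model"]:
        point = config0.copy()   # same hyperparameters...
        point["model"] = model   # ...different model
        points_to_evaluate.append(point)
# points_to_evaluate now holds one identical config per model
```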
@@ -474,14 +585,17 @@ class Completion:
                 log_file_name=log_file_name,
                 verbose=3,
             )
-            config = analysis.best_config
-            params = config.copy()
-            params["prompt"] = cls._prompts[config["prompt"]]
-            stop = cls._stops and cls._stops[config["stop"]]
-            params["stop"] = stop
-            temperature_or_top_p = params.pop("temperature_or_top_p", None)
-            if temperature_or_top_p:
-                params.update(temperature_or_top_p)
+        config = analysis.best_config
+        params = config.copy()
+        if cls._prompts:
+            params["prompt"] = cls._prompts[config["prompt"]]
+        else:
+            params["messages"] = cls._messages[config["messages"]]
+        stop = cls._stops and cls._stops[config["stop"]]
+        params["stop"] = stop
+        temperature_or_top_p = params.pop("temperature_or_top_p", None)
+        if temperature_or_top_p:
+            params.update(temperature_or_top_p)
         return params, analysis

     @classmethod
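A note on the design visible here: candidate prompts, messages, and stop sequences are encoded as integer indices in the search space (`tune.choice(range(...))`), and the winning indices are mapped back to the actual values after tuning. A toy sketch of the same trick (the candidate values and best_config are illustrative):

```python
from flaml import tune  # flaml's tune module, as used in the diff

prompts = ["{prompt}", "# Python 3\n{prompt}"]               # candidate templates
space = {"prompt": tune.choice(list(range(len(prompts))))}   # tune over indices

best_config = {"prompt": 1}  # what the tuner might report
params = dict(best_config)
params["prompt"] = prompts[best_config["prompt"]]  # map the index back to the template
```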
@@ -503,7 +617,43 @@ class Completion:
         if ERROR:
             raise ERROR
         params = config.copy()
-        params["prompt"] = config["prompt"].format(**context)
+        prompt = config.get("prompt")
+        if "messages" in config:
+            params["messages"] = [
+                {
+                    k: v.format(**context) if isinstance(v, str) else v(context)
+                    for k, v in message.items()
+                }
+                for message in config["messages"]
+            ]
+            params.pop("prompt", None)
+        elif config["model"] in cls.chat_models:
+            params["messages"] = [
+                {
+                    "role": "user",
+                    "content": prompt.format(**context)
+                    if isinstance(prompt, str)
+                    else prompt(context),
+                }
+            ]
+            params.pop("prompt", None)
+        else:
+            params["prompt"] = (
+                prompt.format(**context) if isinstance(prompt, str) else prompt(context)
+            )
         if use_cache:
-            return cls._get_response(params)
-        return openai.Completion.create(**params)
+            with diskcache.Cache(cls.cache_path) as cls._cache:
+                return cls._get_response(params)
+        return cls.openai_completion_class.create(**params)


+class ChatCompletion(Completion):
+    """A class for OpenAI API ChatCompletion."""
+
+    price1K = {
+        "gpt-3.5-turbo": 0.002,
+    }
+
+    default_search_space = Completion.default_search_space.copy()
+    default_search_space["model"] = tune.choice(list(price1K.keys()))
+    openai_completion_class = not ERROR and openai.ChatCompletion
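With the new dispatch in `create`, chat and non-chat models share one call path, and `context` fills the prompt or message templates. A usage sketch mirroring the tests in this commit (the config and context values are hand-written stand-ins; a real run needs a valid OpenAI API key):

```python
from flaml import oai

# `config` as returned by a prior tune() call; a stand-in here for illustration
config = {"model": "gpt-3.5-turbo", "prompt": "{prompt}", "max_tokens": 16, "n": 1}
context = {"prompt": "Write a one-line Python hello world."}

# Completion.create picks the chat endpoint for chat models and
# renders the "{prompt}" template with `context`
responses = oai.Completion.create(context=context, **config)
```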
flaml/version.py

@@ -1 +1 @@
-__version__ = "1.1.3"
+__version__ = "1.2.0"
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -458,6 +458,7 @@
     "    \"code-davinci-002\": 0.1,\n",
     "    \"text-davinci-002\": 0.02,\n",
     "    \"text-davinci-003\": 0.02,\n",
+    "    \"gpt-3.5-turbo\": 0.002,\n",
     "}\n",
     "\n",
     "default_search_space = {\n",
@@ -804,5 +804,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 1
-
 }
setup.py
@@ -120,7 +120,7 @@ setuptools.setup(
             "pytorch-forecasting>=0.9.0",
         ],
         "benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],
-        "openai": ["openai==0.23.1", "diskcache", "optuna==2.8.0"],
+        "openai": ["openai==0.27.0", "diskcache", "optuna==2.8.0"],
         "synapse": ["joblibspark>=0.5.0", "optuna==2.8.0", "pyspark>=3.0.0"],
     },
     classifiers=[
@@ -80,13 +80,36 @@ def test_humaneval(num_samples=1):
     oai.Completion.set_cache(seed)
     try:
         # a minimal tuning example
-        oai.Completion.tune(
+        config, _ = oai.Completion.tune(
             data=tune_data,
             metric="success",
             mode="max",
             eval_func=success_metrics,
             n=1,
         )
+        responses = oai.Completion.create(context=test_data[0], **config)
+        # a minimal tuning example for tuning chat completion models using the Completion class
+        config, _ = oai.Completion.tune(
+            data=tune_data,
+            metric="success",
+            mode="max",
+            eval_func=success_metrics,
+            n=1,
+            model="gpt-3.5-turbo",
+        )
+        responses = oai.Completion.create(context=test_data[0], **config)
+        # a minimal tuning example for tuning chat completion models using the ChatCompletion class
+        config, _ = oai.ChatCompletion.tune(
+            data=tune_data,
+            metric="success",
+            mode="max",
+            eval_func=success_metrics,
+            n=1,
+            messages=[{"role": "user", "content": "{prompt}"}],
+        )
+        responses = oai.ChatCompletion.create(context=test_data[0], **config)
+        print(responses)
+        return
         # a more comprehensive tuning example
         config, analysis = oai.Completion.tune(
             data=tune_data,
@@ -94,8 +117,8 @@ def test_humaneval(num_samples=1):
             mode="max",
             eval_func=success_metrics,
             log_file_name="logs/humaneval.log",
-            inference_budget=0.02,
-            optimization_budget=5,
+            inference_budget=0.002,
+            optimization_budget=2,
             num_samples=num_samples,
             prompt=[
                 "{prompt}",
@@ -38,5 +38,13 @@ def test_integrate_openai(save=False):
     run_notebook("integrate_openai.ipynb", save=save)


+@pytest.mark.skipif(
+    skip,
+    reason="do not run openai test if openai is not installed",
+)
+def test_integrate_chatgpt(save=False):
+    run_notebook("integrate_chatgpt_math.ipynb", save=save)
+
+
 if __name__ == "__main__":
-    test_integrate_openai(save=True)
+    test_integrate_chatgpt(save=True)
@@ -1,10 +1,10 @@
-FLAML has integrated the OpenAI's completion API. In this example, we will tune several hyperparameters including the temperature, prompt and n to optimize the inference performance of OpenAI's completion API for a code generation task. Our study shows that tuning hyperparameters can significantly affect the utility of the OpenAI API.
+FLAML offers a cost-effective hyperparameter optimization technique [EcoOptiGen](https://arxiv.org/abs/2303.04673). In this example, we will tune several hyperparameters for OpenAI's completion API, including the temperature, prompt and n (number of completions), to optimize the inference performance for a code generation task. Our study shows that tuning hyperparameters can significantly affect the utility of the OpenAI API.

 ### Prerequisites

-Install the [openai] option. The option is available in flaml since version 1.1.3. This feature is subject to change in future versions.
+Install the [openai] option. The OpenAI integration is in preview. ChatGPT support is available since version 1.2.0.
 ```bash
-pip install "flaml[openai]==1.1.3"
+pip install "flaml[openai]==1.2.0"
 ```
@@ -91,3 +91,14 @@ year={2023},
 url={https://openreview.net/forum?id=0Ij9_q567Ma}
 }
 ```
+
+* [Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference](https://arxiv.org/abs/2303.04673). Chi Wang, Susan Xueqing Liu, Ahmed H. Awadallah. ArXiv preprint arXiv:2303.04673 (2023).
+
+```bibtex
+@inproceedings{wang2023EcoOptiGen,
+    title={Cost-Effective Hyperparameter Optimization for Large Language Model Generation Inference},
+    author={Chi Wang and Susan Xueqing Liu and Ahmed H. Awadallah},
+    year={2023},
+    booktitle={ArXiv preprint arXiv:2303.04673},
+}
+```