Update readme for flaml.tune (#137)

* add time_budget_s for bs in readme

* version update

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
This commit is contained in:
Qingyun Wu 2021-07-24 20:10:43 -04:00 committed by GitHub
parent 95aa719b01
commit 58c0ec959d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 585 additions and 399 deletions

View File

@ -18,11 +18,8 @@ This source file is adapted here because ray does not fully support Windows.
Copyright (c) Microsoft Corporation.
'''
import copy
import glob
import logging
import os
import time
from typing import Dict, Optional, Union, List, Tuple
from typing import Any, Dict, Optional, Union, List, Tuple
import pickle
from .variant_generator import parse_spec_vars
from ..tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
@ -51,12 +48,6 @@ UNDEFINED_METRIC_MODE = str(
"or pass them to `tune.run()`.")
_logged = set()
_disabled = False
_periodic_log = False
_last_logged = 0.0
class Searcher:
"""Abstract class for wrapping suggesting algorithms.
Custom algorithms can extend this class easily by overriding the
@ -341,23 +332,14 @@ class ConcurrencyLimiter(Searcher):
try:
import optuna as ot
from optuna.trial import TrialState as OptunaTrialState
from optuna.samplers import BaseSampler
except ImportError:
ot = None
OptunaTrialState = None
BaseSampler = None
class _Param:
def __getattr__(self, item):
def _inner(*args, **kwargs):
return (item, args, kwargs)
return _inner
param = _Param()
# (Optional) Default (anonymous) metric when using tune.report(x)
DEFAULT_METRIC = "_metric"
@ -395,13 +377,21 @@ class OptunaSearch(Searcher):
configurations.
sampler (optuna.samplers.BaseSampler): Optuna sampler used to
draw hyperparameter configurations. Defaults to ``TPESampler``.
seed (int): The random seed for the sampler
seed (int): Seed to initialize sampler with. This parameter is only
used when ``sampler=None``. In all other cases, the sampler
you pass should be initialized with the seed already.
evaluated_rewards (list): If you have previously evaluated the
parameters passed in as points_to_evaluate you can avoid
re-running those trials by passing in the reward attributes
as a list so the optimiser can be told the results without
needing to re-compute the trial. Must be the same length as
points_to_evaluate.
Tune automatically converts search spaces to Optuna's format:
.. code-block:: python
from ray.tune.suggest.optuna import OptunaSearch
config = {
"a": tune.uniform(6, 8)
"b": tune.uniform(10, 20)
"b": tune.loguniform(1e-4, 1e-2)
}
optuna_search = OptunaSearch(
metric="loss",
@ -410,12 +400,13 @@ class OptunaSearch(Searcher):
If you would like to pass the search space manually, the code would
look like this:
.. code-block:: python
from ray.tune.suggest.optuna import OptunaSearch, param
space = [
param.suggest_uniform("a", 6, 8),
param.suggest_uniform("b", 10, 20)
]
algo = OptunaSearch(
from ray.tune.suggest.optuna import OptunaSearch
import optuna
config = {
"a": optuna.distributions.UniformDistribution(6, 8),
"b": optuna.distributions.LogUniformDistribution(1e-4, 1e-2),
}
optuna_search = OptunaSearch(
space,
metric="loss",
mode="min")
@ -429,7 +420,8 @@ class OptunaSearch(Searcher):
mode: Optional[str] = None,
points_to_evaluate: Optional[List[Dict]] = None,
sampler: Optional[BaseSampler] = None,
seed: Optional[int] = None):
seed: Optional[int] = None,
evaluated_rewards: Optional[List] = None):
assert ot is not None, (
"Optuna must be installed! Run `pip install optuna`.")
super(OptunaSearch, self).__init__(
@ -443,22 +435,39 @@ class OptunaSearch(Searcher):
if domain_vars or grid_vars:
logger.warning(
UNRESOLVED_SEARCH_SPACE.format(
par="space", cls=type(self)))
par="space", cls=type(self).__name__))
space = self.convert_search_space(space)
else:
# Flatten to support nested dicts
space = flatten_dict(space, "/")
# Deprecate: 1.5
if isinstance(space, list):
logger.warning(
"Passing lists of `param.suggest_*()` calls to OptunaSearch "
"as a search space is deprecated and will be removed in "
"a future release of Ray. Please pass a dict mapping "
"to `optuna.distributions` objects instead.")
self._space = space
self._points_to_evaluate = points_to_evaluate
self._points_to_evaluate = points_to_evaluate or []
self._evaluated_rewards = evaluated_rewards
self._study_name = "optuna" # Fixed study name for in-memory storage
if sampler and seed:
logger.warning(
"You passed an initialized sampler to `OptunaSearch`. The "
"`seed` parameter has to be passed to the sampler directly "
"and will be ignored.")
self._sampler = sampler or ot.samplers.TPESampler(seed=seed)
assert isinstance(self._sampler, BaseSampler), \
"You can only pass an instance of `optuna.samplers.BaseSampler` " \
"as a sampler to `OptunaSearcher`."
self._pruner = ot.pruners.NopPruner()
self._storage = ot.storages.InMemoryStorage()
self._ot_trials = {}
self._ot_study = None
if self._space:
@ -469,14 +478,26 @@ class OptunaSearch(Searcher):
# If only a mode was passed, use anonymous metric
self._metric = DEFAULT_METRIC
pruner = ot.pruners.NopPruner()
storage = ot.storages.InMemoryStorage()
self._ot_study = ot.study.create_study(
storage=self._storage,
storage=storage,
sampler=self._sampler,
pruner=self._pruner,
pruner=pruner,
study_name=self._study_name,
direction="minimize" if mode == "min" else "maximize",
load_if_exists=True)
if self._points_to_evaluate:
if self._evaluated_rewards:
for point, reward in zip(self._points_to_evaluate,
self._evaluated_rewards):
self.add_evaluated_point(point, reward)
else:
for point in self._points_to_evaluate:
self._ot_study.enqueue_trial(point)
def set_search_properties(self, metric: Optional[str], mode: Optional[str],
config: Dict) -> bool:
if self._space:
@ -503,22 +524,28 @@ class OptunaSearch(Searcher):
metric=self._metric,
mode=self._mode))
if trial_id not in self._ot_trials:
ot_trial_id = self._storage.create_new_trial(
self._ot_study._study_id)
self._ot_trials[trial_id] = ot.trial.Trial(self._ot_study,
ot_trial_id)
ot_trial = self._ot_trials[trial_id]
if isinstance(self._space, list):
# Keep for backwards compatibility
# Deprecate: 1.5
if trial_id not in self._ot_trials:
self._ot_trials[trial_id] = self._ot_study.ask()
ot_trial = self._ot_trials[trial_id]
if self._points_to_evaluate:
params = self._points_to_evaluate.pop(0)
else:
# getattr will fetch the trial.suggest_ function on Optuna trials
params = {
args[0] if len(args) > 0 else kwargs["name"]: getattr(
ot_trial, fn)(*args, **kwargs)
for (fn, args, kwargs) in self._space
}
else:
# Use Optuna ask interface (since version 2.6.0)
if trial_id not in self._ot_trials:
self._ot_trials[trial_id] = self._ot_study.ask(
fixed_distributions=self._space)
ot_trial = self._ot_trials[trial_id]
params = ot_trial.params
return unflatten_dict(params)
def on_trial_result(self, trial_id: str, result: Dict):
@ -532,32 +559,82 @@ class OptunaSearch(Searcher):
result: Optional[Dict] = None,
error: bool = False):
ot_trial = self._ot_trials[trial_id]
ot_trial_id = ot_trial._trial_id
self._storage.set_trial_value(ot_trial_id, result.get(
self.metric, None))
self._storage.set_trial_state(ot_trial_id,
ot.trial.TrialState.COMPLETE)
val = result.get(self.metric, None) if result else None
ot_trial_state = OptunaTrialState.COMPLETE
if val is None:
if error:
ot_trial_state = OptunaTrialState.FAIL
else:
ot_trial_state = OptunaTrialState.PRUNED
try:
self._ot_study.tell(ot_trial, val, state=ot_trial_state)
except ValueError as exc:
logger.warning(exc) # E.g. if NaN was reported
def add_evaluated_point(self,
parameters: Dict,
value: float,
error: bool = False,
pruned: bool = False,
intermediate_values: Optional[List[float]] = None):
if not self._space:
raise RuntimeError(
UNDEFINED_SEARCH_SPACE.format(
cls=self.__class__.__name__, space="space"))
if not self._metric or not self._mode:
raise RuntimeError(
UNDEFINED_METRIC_MODE.format(
cls=self.__class__.__name__,
metric=self._metric,
mode=self._mode))
ot_trial_state = OptunaTrialState.COMPLETE
if error:
ot_trial_state = OptunaTrialState.FAIL
elif pruned:
ot_trial_state = OptunaTrialState.PRUNED
if intermediate_values:
intermediate_values_dict = {
i: value
for i, value in enumerate(intermediate_values)
}
else:
intermediate_values_dict = None
trial = ot.trial.create_trial(
state=ot_trial_state,
value=value,
params=parameters,
distributions=self._space,
intermediate_values=intermediate_values_dict)
self._ot_study.add_trial(trial)
def save(self, checkpoint_path: str):
save_object = (self._storage, self._pruner, self._sampler,
self._ot_trials, self._ot_study,
self._points_to_evaluate)
save_object = (self._sampler, self._ot_trials, self._ot_study,
self._points_to_evaluate, self._evaluated_rewards)
with open(checkpoint_path, "wb") as outputFile:
pickle.dump(save_object, outputFile)
def restore(self, checkpoint_path: str):
with open(checkpoint_path, "rb") as inputFile:
save_object = pickle.load(inputFile)
self._storage, self._pruner, self._sampler, \
self._ot_trials, self._ot_study, \
self._points_to_evaluate = save_object
if len(save_object) == 5:
self._sampler, self._ot_trials, self._ot_study, \
self._points_to_evaluate, self._evaluated_rewards = save_object
else:
# Backwards compatibility
self._sampler, self._ot_trials, self._ot_study, \
self._points_to_evaluate = save_object
@staticmethod
def convert_search_space(spec: Dict) -> List[Tuple]:
def convert_search_space(spec: Dict) -> Dict[str, Any]:
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
if not domain_vars and not grid_vars:
return []
return {}
if grid_vars:
raise ValueError(
@ -568,13 +645,18 @@ class OptunaSearch(Searcher):
spec = flatten_dict(spec, prevent_delimiter=True)
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
def resolve_value(par: str, domain: Domain) -> Tuple:
def resolve_value(domain: Domain) -> ot.distributions.BaseDistribution:
quantize = None
sampler = domain.get_sampler()
if isinstance(sampler, Quantized):
quantize = sampler.q
sampler = sampler.sampler
if isinstance(sampler, LogUniform):
logger.warning(
"Optuna does not handle quantization in loguniform "
"sampling. The parameter will be passed but it will "
"probably be ignored.")
if isinstance(domain, Float):
if isinstance(sampler, LogUniform):
@ -582,28 +664,31 @@ class OptunaSearch(Searcher):
logger.warning(
"Optuna does not support both quantization and "
"sampling from LogUniform. Dropped quantization.")
return param.suggest_loguniform(par, domain.lower,
domain.upper)
return ot.distributions.LogUniformDistribution(
domain.lower, domain.upper)
elif isinstance(sampler, Uniform):
if quantize:
return param.suggest_discrete_uniform(
par, domain.lower, domain.upper, quantize)
return param.suggest_uniform(par, domain.lower,
domain.upper)
return ot.distributions.DiscreteUniformDistribution(
domain.lower, domain.upper, quantize)
return ot.distributions.UniformDistribution(
domain.lower, domain.upper)
elif isinstance(domain, Integer):
if isinstance(sampler, LogUniform):
if quantize:
logger.warning(
"Optuna does not support both quantization and "
"sampling from LogUniform. Dropped quantization.")
return param.suggest_int(
par, domain.lower, domain.upper, log=True)
return ot.distributions.IntLogUniformDistribution(
domain.lower, domain.upper - 1, step=quantize or 1)
elif isinstance(sampler, Uniform):
return param.suggest_int(
par, domain.lower, domain.upper, step=quantize or 1)
# Upper bound should be inclusive for quantization and
# exclusive otherwise
return ot.distributions.IntUniformDistribution(
domain.lower,
domain.upper - int(bool(not quantize)),
step=quantize or 1)
elif isinstance(domain, Categorical):
if isinstance(sampler, Uniform):
return param.suggest_categorical(par, domain.categories)
return ot.distributions.CategoricalDistribution(
domain.categories)
raise ValueError(
"Optuna search does not support parameters of type "
@ -612,9 +697,9 @@ class OptunaSearch(Searcher):
type(domain.sampler).__name__))
# Parameter name is e.g. "a/b/c" for nested dicts
values = [
resolve_value("/".join(path), domain)
values = {
"/".join(path): resolve_value(domain)
for path, domain in domain_vars
]
}
return values
return values

View File

@ -44,7 +44,7 @@ print(analysis.best_config) # the best config
* Example for using ray tune's API:
```python
# require: pip install flaml[blendsearch] ray[tune]
# require: pip install flaml[blendsearch,ray]
from ray import tune as raytune
from flaml import CFO, BlendSearch
import time
@ -60,18 +60,37 @@ def evaluate_config(config):
# use tune.report to report the metric to optimize
tune.report(metric=metric)
analysis = raytune.run(
evaluate_config, # the function to evaluate a config
config={
# provide a time budget (in seconds) for the tuning process
time_budget_s = 60
# provide the search space
config_search_space = {
'x': tune.lograndint(lower=1, upper=100000),
'y': tune.randint(lower=1, upper=100000)
}, # the search space
}
# provide the low cost partial config
low_cost_partial_config={'x':1}
# set up CFO
search_alg_cfo = CFO(low_cost_partial_config=low_cost_partial_config)
# set up BlendSearch.
search_alg_blendsearch = BlendSearch(metric="metric",
mode="min",
space=config_search_space,
low_cost_partial_config=low_cost_partial_config)
# NOTE that when using BlendSearch as a search_alg in ray tune, you need to
# configure the 'time_budget_s' for BlendSearch accordingly as follows such that BlendSearch is aware of the time budget. This step is not needed when BlendSearch is used as the search_alg in flaml.tune as it is already done automatically in flaml.
search_alg_blendsearch.set_search_properties(config={"time_budget_s": time_budget_s})
analysis = raytune.run(
evaluate_config, # the function to evaluate a config
config=config_search_space,
metric='metric', # the name of the metric used for optimization
mode='min', # the optimization mode, 'min' or 'max'
num_samples=-1, # the maximal number of configs to try, -1 means infinite
time_budget_s=60, # the time budget in seconds
time_budget_s=time_budget_s, # the time budget in seconds
local_dir='logs/', # the local directory to store logs
search_alg=CFO(low_cost_partial_config=[{'x':1}]) # or BlendSearch
search_alg=search_alg_blendsearch # or search_alg_cfo
)
print(analysis.best_trial.last_result) # the best trial's result

View File

@ -1 +1 @@
__version__ = "0.5.8"
__version__ = "0.5.9"

File diff suppressed because one or more lines are too long

View File

@ -48,7 +48,7 @@ setuptools.setup(
"coverage>=5.3",
"xgboost<1.3",
"rgf-python",
"optuna==2.3.0",
"optuna==2.8.0",
"vowpalwabbit",
"openml",
"transformers==4.4.1",
@ -58,10 +58,10 @@ setuptools.setup(
"azure-storage-blob",
],
"blendsearch": [
"optuna==2.3.0"
"optuna==2.8.0"
],
"ray": [
"ray[tune]==1.2.0",
"ray[tune]==1.4.1",
"pyyaml<5.3.1",
],
"azureml": [
@ -74,7 +74,7 @@ setuptools.setup(
"vowpalwabbit",
],
"nlp": [
"ray[tune]>=1.2.0",
"ray[tune]>=1.4.1",
"transformers",
"datasets==1.4.1",
"tensorboardX<=2.2",