mirror of https://github.com/microsoft/autogen.git
Add ChaCha (#92)
* pickle the AutoML object
* get best model per estimator
* test deberta
* stateless API
* pickle the AutoML object
* get best model per estimator
* test deberta
* stateless API
* prevent divide by zero
* test roberta
* BlendSearchTuner
* sync
* version number
* update gitignore
* delta time
* reindex columns when dropping int-indexed columns
* add seed
* add seed in Args
* merge
* init upload of ChaCha
* remove redundancy
* add back catboost
* improve AutoVW API
* set min_resource_lease in VWOnlineTrial
* docstr
* rename
* docstr
* add docstr
* improve API and documentation
* fix name
* docstr
* naming
* remove max_resource in scheduler
* add TODO in flow2
* remove redundancy in rearcher
* add input type
* adapt code from ray.tune
* move files
* naming
* documentation
* fix import error
* fix format issues
* remove cb in worse than test
* improve _generate_all_comb
* remove ray tune
* naming
* VowpalWabbitTrial
* import error
* import error
* merge test code
* scheduler import
* fix import
* remove
* import, minor bug and version
* Float or Categorical
* fix default
* add test_autovw.py
* add vowpalwabbit and openml
* lint
* reorg
* lint
* indent
* add autovw notebook
* update notebook
* update log msg and autovw notebook
* update autovw notebook
* update autovw notebook
* add available strings for model_select_policy
* string for metric
* Update vw format in flaml/onlineml/trial.py

Co-authored-by: olgavrou <olgavrou@gmail.com>

* make init_config optional
* add _setup_trial_runner and update notebook
* space

Co-authored-by: Chi Wang (MSR) <chiw@microsoft.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Qingyun Wu <qiw@microsoft.com>
Co-authored-by: olgavrou <olgavrou@gmail.com>
This commit is contained in:
parent
61d1263dfd
commit
0d3a0bfab6
@@ -5,8 +5,8 @@ This repository incorporates material as listed below or described in the code.
#
## Component. Ray.

-Code in tune/[analysis.py, sample.py, trial.py] and
-searcher/[suggestion.py, variant_generator.py] is adapted from
+Code in tune/[analysis.py, sample.py, trial.py, result.py],
+searcher/[suggestion.py, variant_generator.py], and scheduler/trial_scheduler.py is adapted from
https://github.com/ray-project/ray/blob/master/python/ray/tune/
@@ -141,6 +141,8 @@ For more technical details, please check our papers.
* [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.
* [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.
+* ChaCha for online AutoML. Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. To appear in ICML 2021.

## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
@@ -1,5 +1,9 @@
from flaml.searcher import CFO, BlendSearch, FLOW2, BlendSearchTuner
from flaml.automl import AutoML, logger_formatter
+try:
+    from flaml.onlineml.autovw import AutoVW
+except ImportError:
+    print('need to install vowpalwabbit to use AutoVW')
from flaml.version import __version__
import logging
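The guarded import above exposes AutoVW at the package top level only when vowpalwabbit is installed. A hypothetical downstream check (not part of this commit):

import flaml

if hasattr(flaml, 'AutoVW'):
    print('AutoVW is available')
else:
    print('install vowpalwabbit to enable AutoVW')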
@@ -0,0 +1,2 @@
from .trial import VowpalWabbitTrial
from .trial_runner import OnlineTrialRunner
@ -0,0 +1,188 @@
|
|||
import numpy as np
|
||||
from typing import Optional, Union
|
||||
import logging
|
||||
from flaml.tune import Trial, Categorical, Float, PolynomialExpansionSet, polynomial_expansion_set
|
||||
from flaml.onlineml import OnlineTrialRunner
|
||||
from flaml.scheduler import ChaChaScheduler
|
||||
from flaml.searcher import ChampionFrontierSearcher
|
||||
from flaml.onlineml.trial import get_ns_feature_dim_from_vw_example
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoVW:
|
||||
"""The AutoML class
|
||||
|
||||
Methods:
|
||||
predict(data_sample)
|
||||
learn(data_sample)
|
||||
AUTO
|
||||
"""
|
||||
WARMSTART_NUM = 100
|
||||
AUTO_STRING = '_auto'
|
||||
VW_INTERACTION_ARG_NAME = 'interactions'
|
||||
|
||||
def __init__(self,
|
||||
max_live_model_num: int,
|
||||
search_space: dict,
|
||||
init_config: Optional[dict] = {},
|
||||
min_resource_lease: Optional[Union[str, float]] = 'auto',
|
||||
automl_runner_args: Optional[dict] = {},
|
||||
scheduler_args: Optional[dict] = {},
|
||||
model_select_policy: Optional[str] = 'threshold_loss_ucb',
|
||||
metric: Optional[str] = 'mae_clipped',
|
||||
random_seed: Optional[int] = None,
|
||||
model_selection_mode: Optional[str] = 'min',
|
||||
cb_coef: Optional[float] = None,
|
||||
):
|
||||
"""Constructor
|
||||
|
||||
Args:
|
||||
max_live_model_num: The maximum number of 'live' models, i.e., the maximum
number of models allowed to update in each learning iteration.
|
||||
search_space: A dictionary of the search space. This search space includes both
|
||||
hyperparameters we want to tune and fixed hyperparameters. In the latter case,
|
||||
the value is a fixed value.
|
||||
init_config: A dictionary of a partial or full initial config,
|
||||
e.g. {'interactions': set(), 'learning_rate': 0.5}
|
||||
min_resource_lease: The minimum resource lease assigned to a particular model/trial.
|
||||
If set as 'auto', it will be calculated automatically.
|
||||
automl_runner_args: A dictionary of configuration for the OnlineTrialRunner.
|
||||
If set {}, default values will be used, which is equivalent to using the following configs.
|
||||
automl_runner_args =
|
||||
{"champion_test_policy": 'loss_ucb' # specifcies how to do the statistic test for a better champion
|
||||
"remove_worse": False # specifcies whether to do worse than test
|
||||
}
|
||||
scheduler_args: A dictionary of configuration for the scheduler.
|
||||
If set {}, default values will be used, which is equivalent to using the following configs.
|
||||
scheduler_args =
|
||||
{"keep_challenger_metric": 'ucb' # what metric to use when deciding the top performing challengers
|
||||
"keep_challenger_ratio": 0.5 # denotes the ratio of top performing challengers to keep live
|
||||
"keep_champion": True # specifcies whether to keep the champion always running
|
||||
}
|
||||
model_select_policy: A string in ['threshold_loss_ucb', 'threshold_loss_lcb', 'threshold_loss_avg',
|
||||
'loss_ucb', 'loss_lcb', 'loss_avg'] to specify how to select one model to do prediction
|
||||
from the live model pool. Default value is 'threshold_loss_ucb'.
|
||||
metric: A string in ['mae_clipped', 'mae', 'mse', 'absolute_clipped', 'absolute', 'squared']
|
||||
to specify the name of the loss function used for calculating the progressive validation loss in ChaCha.
|
||||
random_seed (int): An integer of the random seed used in the searcher
(more specifically, this is the random seed for ConfigOracle).
|
||||
model_selection_mode: A string in ['min', 'max'] to specify the objective as
|
||||
minimization or maximization.
|
||||
cb_coef (float): A float coefficient (optional) used in the sample complexity bound.
|
||||
"""
|
||||
self._max_live_model_num = max_live_model_num
|
||||
self._search_space = search_space
|
||||
self._init_config = init_config
|
||||
self._online_trial_args = {"metric": metric,
|
||||
"min_resource_lease": min_resource_lease,
|
||||
"cb_coef": cb_coef,
|
||||
}
|
||||
self._automl_runner_args = automl_runner_args
|
||||
self._scheduler_args = scheduler_args
|
||||
self._model_select_policy = model_select_policy
|
||||
self._model_selection_mode = model_selection_mode
|
||||
self._random_seed = random_seed
|
||||
self._trial_runner = None
|
||||
self._best_trial = None
|
||||
# code for debugging purpose
|
||||
self._prediction_trial_id = None
|
||||
self._iter = 0
|
||||
|
||||
def _setup_trial_runner(self, vw_example):
|
||||
"""Set up the _trial_runner based on one vw_example
|
||||
"""
|
||||
# setup the default search space for the namespace interaction hyperparameter
|
||||
search_space = self._search_space.copy()
|
||||
for k, v in self._search_space.items():
|
||||
if k == self.VW_INTERACTION_ARG_NAME and v == self.AUTO_STRING:
|
||||
raw_namespaces = self.get_ns_feature_dim_from_vw_example(vw_example).keys()
|
||||
search_space[k] = polynomial_expansion_set(init_monomials=set(raw_namespaces))
|
||||
# setup the init config based on the input _init_config and search space
|
||||
init_config = self._init_config.copy()
|
||||
for k, v in search_space.items():
|
||||
if k not in init_config.keys():
|
||||
if isinstance(v, PolynomialExpansionSet):
|
||||
init_config[k] = set()
|
||||
elif (not isinstance(v, Categorical) and not isinstance(v, Float)):
|
||||
init_config[k] = v
|
||||
searcher_args = {"init_config": init_config,
|
||||
"space": search_space,
|
||||
"random_seed": self._random_seed,
|
||||
'online_trial_args': self._online_trial_args,
|
||||
}
|
||||
logger.info("original search_space %s", self._search_space)
|
||||
logger.info("original init_config %s", self._init_config)
|
||||
logger.info('searcher_args %s', searcher_args)
|
||||
logger.info('scheduler_args %s', self._scheduler_args)
|
||||
logger.info('automl_runner_args %s', self._automl_runner_args)
|
||||
searcher = ChampionFrontierSearcher(**searcher_args)
|
||||
scheduler = ChaChaScheduler(**self._scheduler_args)
|
||||
self._trial_runner = OnlineTrialRunner(max_live_model_num=self._max_live_model_num,
|
||||
searcher=searcher,
|
||||
scheduler=scheduler,
|
||||
**self._automl_runner_args)
|
||||
|
||||
def predict(self, data_sample):
|
||||
"""Predict on the input example (e.g., vw example)
|
||||
|
||||
Args:
|
||||
data_sample (vw_example)
|
||||
"""
|
||||
if self._trial_runner is None:
|
||||
self._setup_trial_runner(data_sample)
|
||||
self._best_trial = self._select_best_trial()
|
||||
self._y_predict = self._best_trial.predict(data_sample)
|
||||
# code for debugging purpose
|
||||
if self._prediction_trial_id is None or \
|
||||
self._prediction_trial_id != self._best_trial.trial_id:
|
||||
self._prediction_trial_id = self._best_trial.trial_id
|
||||
logger.info('prediction trial id changed to %s at iter %s, resource used: %s',
|
||||
self._prediction_trial_id, self._iter,
|
||||
self._best_trial.result.resource_used)
|
||||
return self._y_predict
|
||||
|
||||
def learn(self, data_sample):
|
||||
"""Perform one online learning step with the given data sample
|
||||
|
||||
Args:
|
||||
data_sample (vw_example): one data sample on which the model gets updated
|
||||
"""
|
||||
self._iter += 1
|
||||
self._trial_runner.step(data_sample, (self._y_predict, self._best_trial))
|
||||
|
||||
def _select_best_trial(self):
|
||||
"""Select a best trial from the running trials accoring to the _model_select_policy
|
||||
"""
|
||||
best_score = float('+inf') if self._model_selection_mode == 'min' else float('-inf')
|
||||
new_best_trial = None
|
||||
for trial in self._trial_runner.running_trials:
|
||||
if trial.result is not None and ('threshold' not in self._model_select_policy
|
||||
or trial.result.resource_used >= self.WARMSTART_NUM):
|
||||
score = trial.result.get_score(self._model_select_policy)
|
||||
if ('min' == self._model_selection_mode and score < best_score) or \
|
||||
('max' == self._model_selection_mode and score > best_score):
|
||||
best_score = score
|
||||
new_best_trial = trial
|
||||
if new_best_trial is not None:
|
||||
logger.debug('best_trial resource used: %s', new_best_trial.result.resource_used)
|
||||
return new_best_trial
|
||||
else:
|
||||
# This branch will be triggered when the resource consumption of all trials is smaller
|
||||
# than the WARMSTART_NUM threshold. In this case, we will select the _best_trial
|
||||
# selected in the previous iteration.
|
||||
if self._best_trial is not None and self._best_trial.status == Trial.RUNNING:
|
||||
logger.debug('old best trial %s', self._best_trial.trial_id)
|
||||
return self._best_trial
|
||||
else:
|
||||
# this will be triggered in the first iteration or in the iteration where we want
|
||||
# to select the trial from the previous iteration but that trial has been paused
|
||||
# (i.e., self._best_trial.status != Trial.RUNNING) by the scheduler.
|
||||
logger.debug('using champion trial: %s',
|
||||
self._trial_runner.champion_trial.trial_id)
|
||||
return self._trial_runner.champion_trial
|
||||
|
||||
@staticmethod
|
||||
def get_ns_feature_dim_from_vw_example(vw_example) -> dict:
|
||||
"""Get a dictionary of feature dimensionality for each namespace singleton
|
||||
"""
|
||||
return get_ns_feature_dim_from_vw_example(vw_example)
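The class above is driven sample by sample. A minimal progressive-validation sketch (assuming vowpalwabbit is installed; data_samples is a placeholder for an iterable of VW-format example strings):

from flaml import AutoVW

# '_auto' (AutoVW.AUTO_STRING) lets AutoVW infer the raw namespaces from the
# first example and search over their polynomial interactions
autovw = AutoVW(max_live_model_num=5,
                search_space={'interactions': AutoVW.AUTO_STRING})

for vw_example in data_samples:
    y_pred = autovw.predict(vw_example)   # predict with the currently selected best model
    autovw.learn(vw_example)              # then update the live models on the same sample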
|
|
@ -0,0 +1,432 @@
|
|||
import numpy as np
|
||||
import logging
|
||||
import time
|
||||
import math
|
||||
import copy
|
||||
import collections
|
||||
from typing import Dict, Optional, Union
|
||||
from sklearn.metrics import mean_squared_error, mean_absolute_error
|
||||
from vowpalwabbit import pyvw
|
||||
from flaml.tune import Trial
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_ns_feature_dim_from_vw_example(vw_example) -> dict:
|
||||
"""Get a dictionary of feature dimensionality for each namespace singleton
|
||||
|
||||
NOTE:
|
||||
Assumption: the vw_example takes one of the following formats,
|
||||
depending on whether the example includes the feature names
|
||||
|
||||
format 1: 'y |ns1 feature1:feature_value1 feature2:feature_value2 |ns2
feature3:feature_value3 feature4:feature_value4'
|
||||
format 2: 'y | ns1 feature_value1 feature_value2 |
|
||||
ns2 feature_value3 feature_value4'
|
||||
|
||||
The output in both cases is {'ns1': 2, 'ns2': 2}
|
||||
|
||||
For more information about the input format of vw examples, please refer to
|
||||
https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Input-format
|
||||
"""
|
||||
ns_feature_dim = {}
|
||||
data = vw_example.split('|')
|
||||
for i in range(1, len(data)):
|
||||
if ':' in data[i]:
|
||||
ns_w_feature = data[i].split(' ')
|
||||
ns = ns_w_feature[0]
|
||||
feature = ns_w_feature[1:]
|
||||
feature_dim = len(feature)
|
||||
else:
|
||||
data_split = data[i].split(' ')
|
||||
ns = data_split[0]
|
||||
feature_dim = len(data_split) - 1
|
||||
if len(data_split[-1]) == 0:
|
||||
feature_dim -= 1
|
||||
ns_feature_dim[ns] = feature_dim
|
||||
logger.debug('name space feature dimension %s', ns_feature_dim)
|
||||
return ns_feature_dim
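# Illustrative sketch (not part of this module): AutoVW._setup_trial_runner in
# autovw.py above uses this helper to seed the search space over namespace
# interactions from one sample:
#
#   from flaml.tune import polynomial_expansion_set
#   ns_dims = get_ns_feature_dim_from_vw_example(vw_example)   # e.g. {'ns1': 2, 'ns2': 2}
#   space = polynomial_expansion_set(init_monomials=set(ns_dims.keys()))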
|
||||
|
||||
|
||||
class OnlineResult:
|
||||
"""Class for managing the result statistics of a trial
|
||||
|
||||
Attributes:
|
||||
observation_count: the total number of observations
|
||||
resource_used: the total resource (cost) consumed
|
||||
|
||||
Methods:
|
||||
update_result(new_loss, new_resource_used, data_dimension)
|
||||
Update result
|
||||
get_score(score_name)
|
||||
Get the score according to the input score_name
|
||||
"""
|
||||
prob_delta = 0.1
|
||||
LOSS_MIN = 0.0
|
||||
LOSS_MAX = np.inf
|
||||
CB_COEF = 0.05 # 0.001 for mse
|
||||
|
||||
def __init__(self, result_type_name: str, cb_coef: Optional[float] = None,
|
||||
init_loss: Optional[float] = 0.0, init_cb: Optional[float] = 100.0,
|
||||
mode: Optional[str] = 'min', sliding_window_size: Optional[int] = 100):
|
||||
"""
|
||||
Args:
|
||||
result_type_name (str): The name of the result type
|
||||
"""
|
||||
self._result_type_name = result_type_name # for example 'mse' or 'mae'
|
||||
self._mode = mode
|
||||
self._init_loss = init_loss
|
||||
# statistics needed for alg
|
||||
self.observation_count = 0
|
||||
self.resource_used = 0.0
|
||||
self._loss_avg = 0.0
|
||||
self._loss_cb = init_cb # a large number (TODO: this can be changed)
|
||||
self._cb_coef = cb_coef if cb_coef is not None else self.CB_COEF
|
||||
# optional statistics
|
||||
self._sliding_window_size = sliding_window_size
|
||||
self._loss_queue = collections.deque(maxlen=self._sliding_window_size)
|
||||
|
||||
def update_result(self, new_loss, new_resource_used, data_dimension,
|
||||
bound_of_range=1.0, new_observation_count=1.0):
|
||||
"""Update result statistics
|
||||
"""
|
||||
self.resource_used += new_resource_used
|
||||
# keep the running average instead of the sum of loss to avoid overflow
|
||||
self._loss_avg = self._loss_avg * (self.observation_count / (self.observation_count + new_observation_count)
|
||||
) + new_loss / (self.observation_count + new_observation_count)
|
||||
self.observation_count += new_observation_count
|
||||
self._loss_cb = self._update_loss_cb(bound_of_range, data_dimension)
|
||||
self._loss_queue.append(new_loss)
|
||||
|
||||
def _update_loss_cb(self, bound_of_range, data_dim,
|
||||
bound_name='sample_complexity_bound'):
|
||||
"""Calculate bound coef
|
||||
"""
|
||||
if bound_name == 'sample_complexity_bound':
|
||||
# set the coefficient in the loss bound
|
||||
if 'mae' in self.result_type_name:
|
||||
coef = self._cb_coef * bound_of_range
|
||||
else:
|
||||
coef = 0.001 * bound_of_range
|
||||
|
||||
comp_F = math.sqrt(data_dim)
|
||||
n = self.observation_count
|
||||
return coef * comp_F * math.sqrt((np.log10(n / OnlineResult.prob_delta)) / n)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def result_type_name(self):
|
||||
return self._result_type_name
|
||||
|
||||
@property
|
||||
def loss_avg(self):
|
||||
return self._loss_avg if \
|
||||
self.observation_count != 0 else self._init_loss
|
||||
|
||||
@property
|
||||
def loss_cb(self):
|
||||
return self._loss_cb
|
||||
|
||||
@property
|
||||
def loss_lcb(self):
|
||||
return max(self._loss_avg - self._loss_cb, OnlineResult.LOSS_MIN)
|
||||
|
||||
@property
|
||||
def loss_ucb(self):
|
||||
return min(self._loss_avg + self._loss_cb, OnlineResult.LOSS_MAX)
|
||||
|
||||
@property
|
||||
def loss_avg_recent(self):
|
||||
return sum(self._loss_queue) / len(self._loss_queue) \
|
||||
if len(self._loss_queue) != 0 else self._init_loss
|
||||
|
||||
def get_score(self, score_name, cb_ratio=1):
|
||||
if 'lcb' in score_name:
|
||||
return max(self._loss_avg - cb_ratio * self._loss_cb, OnlineResult.LOSS_MIN)
|
||||
elif 'ucb' in score_name:
|
||||
return min(self._loss_avg + cb_ratio * self._loss_cb, OnlineResult.LOSS_MAX)
|
||||
elif 'avg' in score_name:
|
||||
return self._loss_avg
|
||||
else:
|
||||
raise NotImplementedError
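# Illustrative sketch (not part of this module, numbers invented): how OnlineResult
# accumulates the progressive validation loss and the lcb/ucb scores used by the
# model selection policies:
#
#   result = OnlineResult('mae_clipped', cb_coef=0.05)
#   for loss in (0.8, 0.6, 0.7):
#       result.update_result(loss, 1.0, data_dimension=10, bound_of_range=1.0)
#   result.loss_avg, result.loss_lcb, result.loss_ucb
#   result.get_score('threshold_loss_ucb')   # matches the 'ucb' branch above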
|
||||
|
||||
|
||||
class BaseOnlineTrial(Trial):
|
||||
"""Class for online trial.
|
||||
|
||||
Attributes:
|
||||
config: the config for this trial
|
||||
trial_id: the trial_id of this trial
|
||||
min_resource_lease (float): the minimum resource lease
|
||||
status: the status of this trial
|
||||
start_time: the start time of this trial
|
||||
custom_trial_name: a custom name for this trial
|
||||
|
||||
Methods:
|
||||
set_resource_lease(resource)
|
||||
set_status(status)
|
||||
set_checked_under_current_champion(checked_under_current_champion)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: dict,
|
||||
min_resource_lease: float,
|
||||
is_champion: Optional[bool] = False,
|
||||
is_checked_under_current_champion: Optional[bool] = True,
|
||||
custom_trial_name: Optional[str] = 'mae',
|
||||
trial_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
config: the config dict
|
||||
min_resource_lease: the minimum resource lease
|
||||
is_champion: a bool variable
|
||||
is_checked_under_current_champion: a bool variable
|
||||
custom_trial_name: custom trial name
|
||||
trial_id: the trial id
|
||||
"""
|
||||
# ****basic variables
|
||||
self.config = config
|
||||
self.trial_id = trial_id
|
||||
self.status = Trial.PENDING
|
||||
self.start_time = time.time()
|
||||
self.custom_trial_name = custom_trial_name
|
||||
|
||||
# ***resource budget related variable
|
||||
self._min_resource_lease = min_resource_lease
|
||||
self._resource_lease = copy.copy(self._min_resource_lease)
|
||||
# ***champion related variables
|
||||
self._is_champion = is_champion
|
||||
# self._is_checked_under_current_champion_ is supposed to be always 1 when the trial is first created
|
||||
self._is_checked_under_current_champion = is_checked_under_current_champion
|
||||
|
||||
@property
|
||||
def is_champion(self):
|
||||
return self._is_champion
|
||||
|
||||
@property
|
||||
def is_checked_under_current_champion(self):
|
||||
return self._is_checked_under_current_champion
|
||||
|
||||
@property
|
||||
def resource_lease(self):
|
||||
return self._resource_lease
|
||||
|
||||
def set_checked_under_current_champion(self, checked_under_current_champion: bool):
|
||||
"""TODO: add documentation why this is needed. This is needed because sometimes
|
||||
we want to know whether a trial has been paused since a new champion is promoted.
|
||||
We want to try to pause those running trials (even though they are not yet achieve
|
||||
the next scheduling check point according to resource used and resource lease),
|
||||
because a better trial is likely to be in the new challengers generated by the new
|
||||
champion, so we want to try them as soon as possible.
|
||||
If we wait until we reach the next scheduling point, we may waste a lot of resource
|
||||
(depending on what is the current resource lease) on the old trials (note that new
|
||||
trials is not possible to be scheduled to run until there is a slot openning).
|
||||
Intuitively speaking, we want to squize an opening slot as soon as possible once
|
||||
a new champion is promoted, such that we are able to try newly generated challengers.
|
||||
"""
|
||||
self._is_checked_under_current_champion = checked_under_current_champion
|
||||
|
||||
def set_resource_lease(self, resource: float):
|
||||
self._resource_lease = resource
|
||||
|
||||
def set_status(self, status):
|
||||
"""Sets the status of the trial and record the start time
|
||||
"""
|
||||
self.status = status
|
||||
if status == Trial.RUNNING:
|
||||
if self.start_time is None:
|
||||
self.start_time = time.time()
|
||||
|
||||
|
||||
class VowpalWabbitTrial(BaseOnlineTrial):
|
||||
"""Implement BaseOnlineTrial for Vowpal Wabbit
|
||||
|
||||
Attributes:
|
||||
model: the online model
|
||||
result: the anytime result for the online model
|
||||
trainable_class: the model class (set as pyvw.vw for VowpalWabbitTrial)
|
||||
|
||||
config: the config for this trial
|
||||
trial_id: the trial_id of this trial
|
||||
min_resource_lease (float): the minimum resource lease
|
||||
status: the status of this trial
|
||||
start_time: the start time of this trial
|
||||
custom_trial_name: a custom name for this trial
|
||||
|
||||
Methods:
|
||||
set_resource_lease(resource)
|
||||
set_status(status)
|
||||
set_checked_under_current_champion(checked_under_current_champion)
|
||||
|
||||
NOTE:
|
||||
About result:
|
||||
1. training-related results (need to be updated in the trainable class)
2. results about the resource lease (need to be updated externally)
|
||||
|
||||
About namespaces in vw:
|
||||
- Wiki in vw:
|
||||
https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Namespaces
|
||||
- Namespace vs features:
|
||||
https://stackoverflow.com/questions/28586225/in-vowpal-wabbit-what-is-the-difference-between-a-namespace-and-feature
|
||||
"""
|
||||
MODEL_CLASS = pyvw.vw
|
||||
cost_unit = 1.0
|
||||
interactions_config_key = 'interactions'
|
||||
MIN_RES_CONST = 5
|
||||
|
||||
def __init__(self,
|
||||
config: dict,
|
||||
min_resource_lease: float,
|
||||
metric: str = 'mae',
|
||||
is_champion: Optional[bool] = False,
|
||||
is_checked_under_current_champion: Optional[bool] = True,
|
||||
custom_trial_name: Optional[str] = 'vw_mae_clipped',
|
||||
trial_id: Optional[str] = None,
|
||||
cb_coef: Optional[float] = None,
|
||||
):
|
||||
"""Constructor
|
||||
|
||||
Args:
|
||||
config (dict): the config of the trial (note that some hyperparameter
values, such as the namespace interactions, are sets)
|
||||
min_resource_lease (float): the minimum resource lease
|
||||
metric (str): the loss metric
|
||||
is_champion (bool): indicates whether the trial is the current champion or not
|
||||
is_checked_under_current_champion (bool): indicates whether this trial has
|
||||
been paused under the current champion
|
||||
trial_id (str): id of the trial (if None, it will be generated in the constructor)
|
||||
|
||||
"""
|
||||
# attributes
|
||||
self.trial_id = self._config_to_id(config) if trial_id is None else trial_id
|
||||
logger.info('Create trial with trial_id: %s', self.trial_id)
|
||||
super().__init__(config, min_resource_lease, is_champion, is_checked_under_current_champion,
|
||||
custom_trial_name, self.trial_id)
|
||||
self.model = None # model is None until the config is scheduled to run
|
||||
self.result = None
|
||||
self.trainable_class = self.MODEL_CLASS
|
||||
# variables that are needed during online training
|
||||
self._metric = metric
|
||||
self._y_min_observed = None
|
||||
self._y_max_observed = None
|
||||
# application dependent variables
|
||||
self._dim = None
|
||||
self._cb_coef = cb_coef
|
||||
|
||||
@staticmethod
|
||||
def _config_to_id(config):
|
||||
"""Generate an id for the provided config
|
||||
"""
|
||||
# sort config keys
|
||||
sorted_k_list = sorted(list(config.keys()))
|
||||
config_id_full = ''
|
||||
for key in sorted_k_list:
|
||||
v = config[key]
|
||||
config_id = '|'
|
||||
if isinstance(v, set):
|
||||
value_list = sorted(v)
|
||||
config_id += '_'.join([str(k) for k in value_list])
|
||||
else:
|
||||
config_id += str(v)
|
||||
config_id_full = config_id_full + config_id
|
||||
return config_id_full
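# Illustrative sketch (not part of this module, values invented): sorted keys are
# concatenated, each value prefixed with '|', and set values sorted and joined with '_':
#
#   VowpalWabbitTrial._config_to_id({'interactions': {'bc', 'ab'}, 'learning_rate': 0.5})
#   # -> '|ab_bc|0.5'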
|
||||
|
||||
def _initialize_vw_model(self, vw_example):
|
||||
"""Initialize a vw model using the trainable_class
|
||||
"""
|
||||
self._vw_config = self.config.copy()
|
||||
ns_interactions = self.config.get(VowpalWabbitTrial.interactions_config_key, None)
|
||||
# ensure the feature interaction config is a list (required by VW)
|
||||
if ns_interactions is not None:
|
||||
self._vw_config[VowpalWabbitTrial.interactions_config_key] \
|
||||
= list(ns_interactions)
|
||||
# get the dimensionality of the feature according to the namespace configuration
|
||||
namespace_feature_dim = get_ns_feature_dim_from_vw_example(vw_example)
|
||||
self._dim = self._get_dim_from_ns(namespace_feature_dim, ns_interactions)
|
||||
# construct an instance of vw model using the input config and fixed config
|
||||
self.model = self.trainable_class(**self._vw_config)
|
||||
self.result = OnlineResult(self._metric,
|
||||
cb_coef=self._cb_coef,
|
||||
init_loss=0.0, init_cb=100.0,)
|
||||
|
||||
def train_eval_model_online(self, data_sample, y_pred):
|
||||
"""Train and eval model online
|
||||
"""
|
||||
# extract info needed the first time we see the data
|
||||
if self._resource_lease == 'auto' or self._resource_lease is None:
|
||||
assert self._dim is not None
|
||||
self._resource_lease = self._dim * self.MIN_RES_CONST
|
||||
y = self._get_y_from_vw_example(data_sample)
|
||||
self._update_y_range(y)
|
||||
if self.model is None:
|
||||
# initialize self.model and self.result
|
||||
self._initialize_vw_model(data_sample)
|
||||
# do one step of learning
|
||||
self.model.learn(data_sample)
|
||||
# update training related results accordingly
|
||||
new_loss = self._get_loss(y, y_pred, self._metric,
|
||||
self._y_min_observed, self._y_max_observed)
|
||||
# update sample size, sum of loss, and cost
|
||||
data_sample_size = 1
|
||||
bound_of_range = self._y_max_observed - self._y_min_observed
|
||||
if bound_of_range == 0:
|
||||
bound_of_range = 1.0
|
||||
self.result.update_result(new_loss,
|
||||
VowpalWabbitTrial.cost_unit * data_sample_size,
|
||||
self._dim, bound_of_range)
|
||||
|
||||
def predict(self, x):
|
||||
"""Predict using the model
|
||||
"""
|
||||
if self.model is None:
|
||||
# initialize self.model and self.result
|
||||
self._initialize_vw_model(x)
|
||||
return self.model.predict(x)
|
||||
|
||||
def _get_loss(self, y_true, y_pred, loss_func_name, y_min_observed, y_max_observed):
|
||||
"""Get instantaneous loss from y_true and y_pred, and loss_func_name
|
||||
For mae_clip, we clip y_pred in the observed range of y
|
||||
"""
|
||||
if 'mse' in loss_func_name or 'squared' in loss_func_name:
|
||||
loss_func = mean_squared_error
|
||||
elif 'mae' in loss_func_name or 'absolute' in loss_func_name:
|
||||
loss_func = mean_absolute_error
|
||||
if y_min_observed is not None and y_max_observed is not None and \
|
||||
'clip' in loss_func_name:
|
||||
# clip y_pred in the observed range of y
|
||||
y_pred = min(y_max_observed, max(y_pred, y_min_observed))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return loss_func([y_true], [y_pred])
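# Worked instance (invented numbers): with y_true = 2.0, y_pred = 5.0 and an observed
# target range of [0.0, 3.0], 'mae_clipped' first clips the prediction to 3.0 and then
# returns mean_absolute_error([2.0], [3.0]) = 1.0, whereas plain 'mae' would return 3.0.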
|
||||
|
||||
def _update_y_range(self, y):
|
||||
"""Maintain running observed minimum and maximum target value
|
||||
"""
|
||||
if self._y_min_observed is None or y < self._y_min_observed:
|
||||
self._y_min_observed = y
|
||||
if self._y_max_observed is None or y > self._y_max_observed:
|
||||
self._y_max_observed = y
|
||||
|
||||
@staticmethod
|
||||
def _get_dim_from_ns(namespace_feature_dim: dict, namespace_interactions: Union[set, list]):
|
||||
"""Get the dimensionality of the corresponding feature of input namespace set
|
||||
"""
|
||||
total_dim = sum(namespace_feature_dim.values())
|
||||
if namespace_interactions:
|
||||
for f in namespace_interactions:
|
||||
ns_dim = 1.0
|
||||
for c in f:
|
||||
ns_dim *= namespace_feature_dim[c]
|
||||
total_dim += ns_dim
|
||||
return total_dim
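# Worked instance (invented dimensions): with namespace_feature_dim = {'a': 2, 'b': 3}
# and namespace_interactions = {'ab'}, the total dimensionality is 2 + 3 + 2 * 3 = 11,
# since the interaction 'ab' contributes the product of the two namespace dimensions.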
|
||||
|
||||
def clean_up_model(self):
|
||||
self.model = None
|
||||
self.result = None
|
||||
|
||||
@staticmethod
|
||||
def _get_y_from_vw_example(vw_example):
|
||||
"""Get y from a vw_example. this works for regression datasets.
|
||||
"""
|
||||
return float(vw_example.split('|')[0])
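Putting this file together, a single trial's online loop looks roughly like the following sketch (vw_example stands in for one VW-format example string; the config is illustrative only):

trial = VowpalWabbitTrial(config={'interactions': set()},
                          min_resource_lease='auto', metric='mae_clipped')
y_pred = trial.predict(vw_example)                  # lazily builds the pyvw model on first use
trial.train_eval_model_online(vw_example, y_pred)   # one learning step plus result update
print(trial.result.loss_ucb, trial.result.resource_used)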
|
|
@ -0,0 +1,495 @@
|
|||
import time
|
||||
import numpy as np
|
||||
import math
|
||||
from flaml.tune import Trial
|
||||
from flaml.scheduler import TrialScheduler
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OnlineTrialRunner:
|
||||
"""The OnlineTrialRunner class
|
||||
|
||||
Methods:
|
||||
step(max_live_model_num, data_sample, prediction_trial_tuple)
|
||||
Schedules up to _max_live_model_num trials to run each time it is called
|
||||
get_top_running_trials()
|
||||
Get a list of trial ids, whose performance is among the top running trials
|
||||
add_trial(trial)
|
||||
Add trial to this TrialRunner.
|
||||
stop_trial(trial)
|
||||
Set the status of a trial to be Trial.TERMINATED and perform other subsequent operations
|
||||
pause_trial(trial)
|
||||
Set the status of a trial to be Trial.PAUSED and perform other subsequent operations
|
||||
run_trial(trial)
|
||||
Set the status of a trial to be Trial.RUNNING and perform other subsequent operations
|
||||
get_trials()
|
||||
Get all the trials added (whatever their status) in the OnlineTrialRunner
|
||||
|
||||
NOTE about the status of a trial:
|
||||
Trial.PENDING: All trials are set to be pending when first added into the OnlineTrialRunner until
|
||||
it is selected to run. By this definition, a trial with status Trial.PENDING is a challenger
|
||||
trial added to the OnlineTrialRunner but never been selected to run.
|
||||
It denotes the starting of trial's lifespan in the OnlineTrialRunner.
|
||||
Trial.RUNNING: It indicates that this trial is one of the concurrently running trials.
|
||||
The max number of Trial.RUNNING trials is running_budget.
|
||||
The status of a trial will be set to Trial.RUNNING the next time it is selected to run.
|
||||
A trial's status may have the following change:
|
||||
Trial.PENDING -> Trial.RUNNING
|
||||
Trial.PAUSED -> Trial.RUNNING
|
||||
Trial.PAUSED: The status of a trial is set to Trial.PAUSED once it is removed from the running trials.
|
||||
Trial.RUNNING -> Trial.PAUSED
|
||||
Trial.TERMINATED: set the status of a trial to Trial.TERMINATED when you never want to select it.
|
||||
It denotes the real end of a trial's lifespan.
|
||||
Status change routine of a trial
|
||||
Trial.PENDING -> (Trial.RUNNING -> Trial.PAUSED -> Trial.RUNNING -> ...) -> Trial.TERMINATED(optional)
|
||||
"""
|
||||
RANDOM_SEED = 123456
|
||||
WARMSTART_NUM = 100
|
||||
|
||||
def __init__(self,
|
||||
max_live_model_num: int,
|
||||
searcher=None,
|
||||
scheduler=None,
|
||||
champion_test_policy='loss_ucb',
|
||||
**kwargs
|
||||
):
|
||||
"""Constructor
|
||||
|
||||
Args:
|
||||
max_live_model_num: The maximum number of 'live'/running models allowed.
|
||||
searcher: A class for generating Trial objects progressively. The ConfigOracle
|
||||
is implemented in the searcher.
|
||||
Required methods of the searcher:
|
||||
- next_trial()
|
||||
Generate the next trial to add.
|
||||
- set_search_properties(metric: Optional[str], mode: Optional[str], config: dict)
|
||||
Generate new challengers based on the current champion and update the challenger list
|
||||
- on_trial_result(trial_id: str, result: Dict)
|
||||
Report results to the searcher.
|
||||
scheduler: A class for managing the 'live' trials and allocating the resources for the trials.
|
||||
Required methods of the scheduler:
|
||||
- on_trial_add(trial_runner, trial: Trial)
|
||||
It adds candidate trials to the scheduler. It is called inside of the add_trial
|
||||
function in the TrialRunner.
|
||||
- on_trial_remove(trial_runner, trial: Trial)
|
||||
Remove terminated trials from the scheduler.
|
||||
- on_trial_result(trial_runner, trial: Trial, result: Dict)
|
||||
Report results to the scheduler.
|
||||
- choose_trial_to_run(trial_runner) -> Optional[Trial]
|
||||
Among them, on_trial_result and choose_trial_to_run are the most important methods
|
||||
champion_test_policy: A string to specify which statistical test to use when testing for a new champion.
|
||||
Currently can choose from ['loss_ucb', 'loss_avg', 'loss_lcb', None].
|
||||
"""
|
||||
# OnlineTrialRunner setting
|
||||
self._searcher = searcher
|
||||
self._scheduler = scheduler
|
||||
self._champion_test_policy = champion_test_policy
|
||||
self._max_live_model_num = max_live_model_num
|
||||
self._remove_worse = kwargs.get('remove_worse', True)
|
||||
self._bound_trial_num = kwargs.get('bound_trial_num', False)
|
||||
self._no_model_persistence = True
|
||||
|
||||
# stores all the trials added to the OnlineTrialRunner
|
||||
# i.e., include the champion and all the challengers
|
||||
self._trials = []
|
||||
self._champion_trial = None
|
||||
self._best_challenger_trial = None
|
||||
self._first_challenger_pool_size = None
|
||||
self._random_state = np.random.RandomState(self.RANDOM_SEED)
|
||||
self._running_trials = set()
|
||||
|
||||
# initially schedule up to max_live_model_num of live models and
|
||||
# set the first trial as the champion (which is done inside self.step())
|
||||
self._total_steps = 0
|
||||
logger.info('init step %s', self._max_live_model_num)
|
||||
# TODO: add more comments
|
||||
self.step()
|
||||
assert self._champion_trial is not None
|
||||
|
||||
@property
|
||||
def champion_trial(self) -> Trial:
|
||||
"""The champion trial
|
||||
"""
|
||||
return self._champion_trial
|
||||
|
||||
@property
|
||||
def running_trials(self):
|
||||
"""The running/'live' trials
|
||||
"""
|
||||
return self._running_trials
|
||||
|
||||
def step(self, data_sample=None, prediction_trial_tuple=None):
|
||||
"""Schedule up to max_live_model_num trials to run
|
||||
|
||||
Args:
|
||||
data_sample
|
||||
prediction_trial_tuple
|
||||
|
||||
NOTE:
|
||||
It consists of the following several parts:
|
||||
Update model:
|
||||
0. Update running trials using observations received.
|
||||
Tests for Champion
|
||||
1. Test for champion (BetterThan test, and WorseThan test)
|
||||
1.1 BetterThan test
|
||||
1.2 WorseThan test: a trial may be removed if the WorseThan test is triggered
|
||||
Online Scheduling:
|
||||
2. Report results to the searcher and scheduler (the scheduler will return a decision about
|
||||
the status of the running trials).
|
||||
3. Pause or stop a trial according to the scheduler's decision.
|
||||
Add trial into the OnlineTrialRunner if there are opening slots.
|
||||
|
||||
TODO:
|
||||
add documentation about the Args
|
||||
"""
|
||||
# ***********Update running trials with observation***************************
|
||||
if data_sample is not None:
|
||||
self._total_steps += 1
|
||||
prediction_made, prediction_trial = prediction_trial_tuple[0], prediction_trial_tuple[1]
|
||||
# assert prediction_trial.status == Trial.RUNNING
|
||||
trials_to_pause = []
|
||||
for trial in list(self._running_trials):
|
||||
if trial != prediction_trial:
|
||||
y_predicted = trial.predict(data_sample)
|
||||
else:
|
||||
y_predicted = prediction_made
|
||||
trial.train_eval_model_online(data_sample, y_predicted)
|
||||
logger.debug('running trial at iter %s %s %s %s %s %s', self._total_steps,
|
||||
trial.trial_id, trial.result.loss_avg, trial.result.loss_cb,
|
||||
trial.result.resource_used, trial.resource_lease)
|
||||
# report result to the searcher
|
||||
self._searcher.on_trial_result(trial.trial_id, trial.result)
|
||||
# report result to the scheduler and the scheduler makes a decision about
|
||||
# the running status of the trial
|
||||
decision = self._scheduler.on_trial_result(self, trial, trial.result)
|
||||
# set the status of the trial according to the decision made by the scheduler
|
||||
logger.debug('trial decision %s %s at step %s', decision, trial.trial_id, self._total_steps)
|
||||
if decision == TrialScheduler.STOP:
|
||||
self.stop_trial(trial)
|
||||
elif decision == TrialScheduler.PAUSE:
|
||||
trials_to_pause.append(trial)
|
||||
else:
|
||||
self.run_trial(trial)
|
||||
# ***********Statistical test of champion*************************************
|
||||
self._champion_test()
|
||||
# Pause the trial after the tests because the tests involve the reset of the trial's result
|
||||
for trial in trials_to_pause:
|
||||
self.pause_trial(trial)
|
||||
# ***********Add and schedule new trials to run if there are opening slots****
|
||||
# Add trial if needed: add challengers into consideration through _add_trial_from_searcher()
|
||||
# if there are available slots
|
||||
for _ in range(self._max_live_model_num - len(self._running_trials)):
|
||||
self._add_trial_from_searcher()
|
||||
# Scheduling: schedule up to max_live_model_num number of trials to run
|
||||
# (set the status as Trial.RUNNING)
|
||||
while self._max_live_model_num > len(self._running_trials):
|
||||
trial_to_run = self._scheduler.choose_trial_to_run(self)
|
||||
if trial_to_run is not None:
|
||||
self.run_trial(trial_to_run)
|
||||
else:
|
||||
break
|
||||
|
||||
def get_top_running_trials(self, top_ratio=None, top_metric='ucb') -> list:
|
||||
"""Get a list of trial ids, whose performance is among the top running trials
|
||||
"""
|
||||
running_valid_trials = [trial for trial in self._running_trials if
|
||||
trial.result is not None]
|
||||
if not running_valid_trials:
|
||||
return
|
||||
if top_ratio is None:
|
||||
top_number = 0
|
||||
elif isinstance(top_ratio, float):
|
||||
top_number = math.ceil(len(running_valid_trials) * top_ratio)
|
||||
elif isinstance(top_ratio, str) and 'best' in top_ratio:
|
||||
top_number = 1
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if 'ucb' in top_metric:
|
||||
test_attribute = 'loss_ucb'
|
||||
elif 'avg' in top_metric:
|
||||
test_attribute = 'loss_avg'
|
||||
elif 'lcb' in top_metric:
|
||||
test_attribute = 'loss_lcb'
|
||||
else:
|
||||
raise NotImplementedError
|
||||
top_running_valid_trials = []
|
||||
logger.info('Running trial ids %s', [trial.trial_id for trial in running_valid_trials])
|
||||
self._random_state.shuffle(running_valid_trials)
|
||||
results = [trial.result.get_score(test_attribute) for trial in running_valid_trials]
|
||||
sorted_index = np.argsort(np.array(results)) # sorted result (small to large) index
|
||||
for i in range(min(top_number, len(running_valid_trials))):
|
||||
top_running_valid_trials.append(running_valid_trials[sorted_index[i]])
|
||||
logger.info('Top running ids %s', [trial.trial_id for trial in top_running_valid_trials])
|
||||
return top_running_valid_trials
|
||||
|
||||
def _add_trial_from_searcher(self):
|
||||
"""Add a new trial to this TrialRunner.
|
||||
|
||||
NOTE:
|
||||
The new trial is acquired from the input search algorithm, i.e. self._searcher
|
||||
A 'new' trial means the trial is not in self._trials
|
||||
"""
|
||||
# (optionally) upper bound the number of trials in the OnlineTrialRunner
|
||||
if self._bound_trial_num and self._first_challenger_pool_size is not None:
|
||||
active_trial_size = len([t for t in self._trials if t.status != Trial.TERMINATED])
|
||||
trial_num_upper_bound = int(round((np.log10(self._total_steps) + 1) * self._first_challenger_pool_size)
|
||||
) if self._first_challenger_pool_size else np.inf
|
||||
if active_trial_size > trial_num_upper_bound:
|
||||
logger.info('Not adding new trials: %s exceeds trial limit %s.',
|
||||
active_trial_size, trial_num_upper_bound)
|
||||
return None
|
||||
|
||||
# output one trial from the trial pool (new challenger pool) maintained in the searcher
|
||||
# Assumption on the searcher: when all frontiers (i.e., all the challengers generated
|
||||
# based on the current champion) of the current champion are added, calling next_trial()
|
||||
# will return None
|
||||
trial = self._searcher.next_trial()
|
||||
if trial is not None:
|
||||
self.add_trial(trial) # dup checked in add_trial
|
||||
# the champion_trial is initially None, so we need to set it up the first time
|
||||
# a valid trial is added.
|
||||
# Assumption on self._searcher: the first trial generated is the champion trial
|
||||
if self._champion_trial is None:
|
||||
logger.info('Initial set up of the champion trial %s', trial.config)
|
||||
self._set_champion(trial)
|
||||
else:
|
||||
self._all_new_challengers_added = True
|
||||
if self._first_challenger_pool_size is None:
|
||||
self._first_challenger_pool_size = len(self._trials)
|
||||
|
||||
def _champion_test(self):
|
||||
"""Perform tests again the latest champion, including bette_than tests and worse_than tests
|
||||
"""
|
||||
# for BetterThan test, we only need to compare the best challenger with the champion
|
||||
self._get_best_challenger()
|
||||
if self._best_challenger_trial is not None:
|
||||
assert self._best_challenger_trial.trial_id != self._champion_trial.trial_id
|
||||
# test whether a new champion is found and set the trial properties accordingly
|
||||
is_new_champion_found = self._better_than_champion_test(self._best_challenger_trial)
|
||||
if is_new_champion_found:
|
||||
self._set_champion(new_champion_trial=self._best_challenger_trial)
|
||||
|
||||
# performs _worse_than_champion_test, which is an optional component in ChaCha
|
||||
if self._remove_worse:
|
||||
to_stop = []
|
||||
for trial_to_test in self._trials:
|
||||
if trial_to_test.status != Trial.TERMINATED:
|
||||
worse_than_champion = self._worse_than_champion_test(
|
||||
self._champion_trial, trial_to_test, self.WARMSTART_NUM)
|
||||
if worse_than_champion:
|
||||
to_stop.append(trial_to_test)
|
||||
# we want to ensure there are at least #max_live_model_num of challengers remaining
|
||||
max_to_stop_num = len([t for t in self._trials if t.status != Trial.TERMINATED]
|
||||
) - self._max_live_model_num
|
||||
for i in range(min(max_to_stop_num, len(to_stop))):
|
||||
self.stop_trial(to_stop[i])
|
||||
|
||||
def _get_best_challenger(self):
|
||||
"""Get the 'best' (in terms of the champion_test_policy) challenger under consideration.
|
||||
"""
|
||||
if self._champion_test_policy is None:
|
||||
return
|
||||
if 'ucb' in self._champion_test_policy:
|
||||
test_attribute = 'loss_ucb'
|
||||
elif 'avg' in self._champion_test_policy:
|
||||
test_attribute = 'loss_avg'
|
||||
else:
|
||||
raise NotImplementedError
|
||||
active_trials = [trial for trial in self._trials if
|
||||
(trial.status != Trial.TERMINATED
|
||||
and trial.trial_id != self._champion_trial.trial_id
|
||||
and trial.result is not None)]
|
||||
if active_trials:
|
||||
self._random_state.shuffle(active_trials)
|
||||
results = [trial.result.get_score(test_attribute) for trial in active_trials]
|
||||
best_index = np.argmin(results)
|
||||
self._best_challenger_trial = active_trials[best_index]
|
||||
|
||||
def _set_champion(self, new_champion_trial):
|
||||
"""Set the status of the existing trials once a new champion is found.
|
||||
"""
|
||||
assert new_champion_trial is not None
|
||||
is_init_update = False
|
||||
if self._champion_trial is None:
|
||||
is_init_update = True
|
||||
self.run_trial(new_champion_trial)
|
||||
# set the checked_under_current_champion status of the trials
|
||||
for trial in self._trials:
|
||||
if trial.trial_id == new_champion_trial.trial_id:
|
||||
trial.set_checked_under_current_champion(True)
|
||||
else:
|
||||
trial.set_checked_under_current_champion(False)
|
||||
self._champion_trial = new_champion_trial
|
||||
self._all_new_challengers_added = False
|
||||
logger.info('Set the champion as %s', self._champion_trial.trial_id)
|
||||
if not is_init_update:
|
||||
self._champion_update_times += 1
|
||||
# calling set_search_properties of searcher will trigger
|
||||
# new challenger generation. we do not do this for init champion
|
||||
# as this step is already done when first constructing the searcher
|
||||
self._searcher.set_search_properties(None, None,
|
||||
{self._searcher.CHAMPION_TRIAL_NAME: self._champion_trial}
|
||||
)
|
||||
else:
|
||||
self._champion_update_times = 0
|
||||
|
||||
def get_trials(self) -> list:
|
||||
"""Return the list of trials managed by this TrialRunner.
|
||||
"""
|
||||
return self._trials
|
||||
|
||||
def add_trial(self, new_trial):
|
||||
"""Add a new trial to this TrialRunner.
|
||||
|
||||
Trials may be added at any time.
|
||||
|
||||
Args:
|
||||
trial (Trial): Trial to queue.
|
||||
|
||||
NOTE:
|
||||
Only add the new trial when it does not exist (according to the trial_id, which is
|
||||
the signature of the trial) in self._trials.
|
||||
"""
|
||||
for trial in self._trials:
|
||||
if trial.trial_id == new_trial.trial_id:
|
||||
trial.set_checked_under_current_champion(True)
|
||||
return
|
||||
logger.info('adding trial at iter %s, %s %s', self._total_steps, new_trial.trial_id,
|
||||
len(self._trials))
|
||||
self._trials.append(new_trial)
|
||||
self._scheduler.on_trial_add(self, new_trial)
|
||||
|
||||
def stop_trial(self, trial):
|
||||
"""Stop a trial: set the status of a trial to be Trial.TERMINATED and perform
|
||||
other subsequent operations
|
||||
"""
|
||||
if trial.status in [Trial.ERROR, Trial.TERMINATED]:
|
||||
return
|
||||
else:
|
||||
logger.info('Terminating trial %s, with trial result %s',
|
||||
trial.trial_id, trial.result)
|
||||
trial.set_status(Trial.TERMINATED)
|
||||
# clean up model and result
|
||||
trial.clean_up_model()
|
||||
self._scheduler.on_trial_remove(self, trial)
|
||||
self._searcher.on_trial_complete(trial.trial_id)
|
||||
self._running_trials.remove(trial)
|
||||
|
||||
def pause_trial(self, trial):
|
||||
"""Pause a trial: set the status of a trial to be Trial.PAUSED and perform other
|
||||
subsequent operations
|
||||
"""
|
||||
if trial.status in [Trial.ERROR, Trial.TERMINATED]:
|
||||
return
|
||||
else:
|
||||
logger.info('Pausing trial %s, with trial loss_avg: %s, loss_cb: %s, loss_ucb: %s,\
|
||||
resource_lease: %s', trial.trial_id, trial.result.loss_avg,
|
||||
trial.result.loss_cb, trial.result.loss_avg + trial.result.loss_cb,
|
||||
trial.resource_lease)
|
||||
trial.set_status(Trial.PAUSED)
|
||||
# clean up model and result if no model persistence
|
||||
if self._no_model_persistence:
|
||||
trial.clean_up_model()
|
||||
self._running_trials.remove(trial)
|
||||
|
||||
def run_trial(self, trial):
|
||||
"""Run a trial: set the status of a trial to be Trial.RUNNING and perform other
|
||||
subsequent operations
|
||||
"""
|
||||
if trial.status in [Trial.ERROR, Trial.TERMINATED]:
|
||||
return
|
||||
else:
|
||||
trial.set_status(Trial.RUNNING)
|
||||
self._running_trials.add(trial)
|
||||
|
||||
def _better_than_champion_test(self, trial_to_test):
|
||||
"""Test whether there is a config in the existing trials that is better than
|
||||
the current champion config
|
||||
|
||||
Returns:
|
||||
A bool indicating whether a new champion is found
|
||||
"""
|
||||
if trial_to_test.result is not None and self._champion_trial.result is not None:
|
||||
if 'ucb' in self._champion_test_policy:
|
||||
return self._test_lcb_ucb(self._champion_trial, trial_to_test, self.WARMSTART_NUM)
|
||||
elif 'avg' in self._champion_test_policy:
|
||||
return self._test_avg_loss(self._champion_trial, trial_to_test, self.WARMSTART_NUM)
|
||||
elif 'martingale' in self._champion_test_policy:
|
||||
return self._test_martingale(self._champion_trial, trial_to_test)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _worse_than_champion_test(champion_trial, trial, warmstart_num=1) -> bool:
|
||||
"""Test whether the input trial is worse than the champion_trial
|
||||
"""
|
||||
if trial.result is not None and trial.result.resource_used >= warmstart_num:
|
||||
if trial.result.loss_lcb > champion_trial.result.loss_ucb:
|
||||
logger.info('=========trial %s is worse than champion %s=====',
|
||||
trial.trial_id, champion_trial.trial_id)
|
||||
logger.info('trial %s %s %s', trial.config, trial.result, trial.resource_lease)
|
||||
logger.info('trial loss_avg:%s, trial loss_cb %s', trial.result.loss_avg,
|
||||
trial.result.loss_cb)
|
||||
logger.info('champion loss_avg:%s, champion loss_cb %s', champion_trial.result.loss_avg,
|
||||
champion_trial.result.loss_cb)
|
||||
logger.info('champion %s', champion_trial.config)
|
||||
logger.info('trial loss_avg_recent:%s, trial loss_cb %s', trial.result.loss_avg_recent,
|
||||
trial.result.loss_cb)
|
||||
logger.info('champion loss_avg_recent:%s, champion loss_cb %s',
|
||||
champion_trial.result.loss_avg_recent, champion_trial.result.loss_cb)
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _test_lcb_ucb(champion_trial, trial, warmstart_num=1) -> bool:
|
||||
"""Comare the challenger(i.e., trial)'s loss upper bound with
|
||||
champion_trial's loss lower bound - cb
|
||||
"""
|
||||
assert trial.trial_id != champion_trial.trial_id
|
||||
if trial.result.resource_used >= warmstart_num:
|
||||
if trial.result.loss_ucb < champion_trial.result.loss_lcb - champion_trial.result.loss_cb:
|
||||
logger.info('======new champion condition satisfied: using lcb vs ucb=====')
|
||||
logger.info('new champion trial %s %s %s',
|
||||
trial.trial_id, trial.result.resource_used, trial.resource_lease)
|
||||
logger.info('new champion trial loss_avg:%s, trial loss_cb %s',
|
||||
trial.result.loss_avg, trial.result.loss_cb)
|
||||
logger.info('old champion trial %s %s %s',
|
||||
champion_trial.trial_id, champion_trial.result.resource_used,
|
||||
champion_trial.resource_lease,)
|
||||
logger.info('old champion loss avg %s, loss cb %s',
|
||||
champion_trial.result.loss_avg,
|
||||
champion_trial.result.loss_cb)
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _test_avg_loss(champion_trial, trial, warmstart_num=1) -> bool:
|
||||
"""Comare the challenger(i.e., trial)'s average loss with the
|
||||
champion_trial's average loss
|
||||
"""
|
||||
assert trial.trial_id != champion_trial.trial_id
|
||||
if trial.result.resource_used >= warmstart_num:
|
||||
if trial.result.loss_avg < champion_trial.result.loss_avg:
|
||||
logger.info('=====new champion condition satisfied using avg loss=====')
|
||||
logger.info('trial %s', trial.config)
|
||||
logger.info('trial loss_avg:%s, trial loss_cb %s',
|
||||
trial.result.loss_avg, trial.result.loss_cb)
|
||||
logger.info('champion loss_avg:%s, champion loss_cb %s',
|
||||
champion_trial.result.loss_avg, champion_trial.result.loss_cb)
|
||||
logger.info('champion %s', champion_trial.config)
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _test_martingale(champion_trial, trial):
|
||||
"""Comare the challenger and champion using confidence sequence based
|
||||
test martingale
|
||||
|
||||
Not implemented yet
|
||||
"""
|
||||
raise NotImplementedError
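To make the two champion tests above concrete (all numbers invented): suppose the champion has loss_avg = 0.50 and loss_cb = 0.05 (so loss_lcb = 0.45 and loss_ucb = 0.55), and a challenger past the WARMSTART_NUM threshold has loss_avg = 0.30 and loss_cb = 0.04 (loss_ucb = 0.34). Then _test_lcb_ucb promotes the challenger because 0.34 < 0.45 - 0.05 = 0.40, while _worse_than_champion_test would eliminate any trial whose loss_lcb rises above the champion's loss_ucb of 0.55.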
|
|
@@ -0,0 +1,2 @@
from .trial_scheduler import TrialScheduler, FIFOScheduler
from .online_scheduler import OnlineScheduler, OnlineSuccessiveDoublingScheduler, ChaChaScheduler
@ -0,0 +1,140 @@
|
|||
import numpy as np
|
||||
import logging
|
||||
from typing import Optional, Dict
|
||||
from flaml.scheduler import FIFOScheduler, TrialScheduler
|
||||
from flaml.tune import Trial
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OnlineScheduler(FIFOScheduler):
|
||||
"""Implementation of the OnlineFIFOSchedulers.
|
||||
|
||||
Methods:
|
||||
on_trial_result(trial_runner, trial, result)
|
||||
Report result and return a decision on the trial's status
|
||||
choose_trial_to_run(trial_runner)
|
||||
Decide which trial to run next
|
||||
"""
|
||||
def on_trial_result(self, trial_runner, trial: Trial, result: Dict):
|
||||
"""Report result and return a decision on the trial's status
|
||||
|
||||
Always keep a trial running (return status TrialScheduler.CONTINUE)
|
||||
"""
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def choose_trial_to_run(self, trial_runner) -> Trial:
|
||||
"""Decide which trial to run next
|
||||
|
||||
Trial prioritization according to status:
PENDING (trials that have not been tried) > PAUSED (trials that have been run)
|
||||
|
||||
For trials with the same status, it chooses the ones with smaller resource lease
|
||||
"""
|
||||
for trial in trial_runner.get_trials():
|
||||
if trial.status == Trial.PENDING:
|
||||
return trial
|
||||
min_paused_resource = np.inf
|
||||
min_paused_resource_trial = None
|
||||
for trial in trial_runner.get_trials():
|
||||
# if there is a tie, prefer the earlier added ones
|
||||
if trial.status == Trial.PAUSED and trial.resource_lease < min_paused_resource:
|
||||
min_paused_resource = trial.resource_lease
|
||||
min_paused_resource_trial = trial
|
||||
if min_paused_resource_trial is not None:
|
||||
return min_paused_resource_trial
|
||||
|
||||
|
||||
class OnlineSuccessiveDoublingScheduler(OnlineScheduler):
|
||||
"""Implementation of the OnlineSuccessiveDoublingScheduler.
|
||||
|
||||
Methods:
|
||||
on_trial_result(trial_runner, trial, result)
|
||||
Report result and return a decision on the trial's status
|
||||
choose_trial_to_run(trial_runner)
|
||||
Decide which trial to run next
|
||||
"""
|
||||
def __init__(self, increase_factor: float = 2.0):
|
||||
'''
|
||||
Args:
|
||||
increase_factor (float): a multiplicative factor used to increase resource lease.
|
||||
The default value is 2.0
|
||||
'''
|
||||
super().__init__()
|
||||
self._increase_factor = increase_factor
|
||||
|
||||
def on_trial_result(self, trial_runner, trial: Trial, result: Dict):
|
||||
"""Report result and return a decision on the trial's status
|
||||
|
||||
1. Returns TrialScheduler.CONTINUE (i.e., keep the trial running),
|
||||
if the resource consumed has not reached the current resource_lease;
2. otherwise, double the current resource lease and return TrialScheduler.PAUSE
|
||||
"""
|
||||
if trial.result is None or trial.result.resource_used < trial.resource_lease:
|
||||
return TrialScheduler.CONTINUE
|
||||
else:
|
||||
trial.set_resource_lease(trial.resource_lease * self._increase_factor)
|
||||
logger.info('Doubled resource for trial %s, used: %s, current budget %s',
|
||||
trial.trial_id, trial.result.resource_used, trial.resource_lease)
|
||||
return TrialScheduler.PAUSE
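# For instance (invented numbers): a trial whose initial lease is 5 keeps running
# (CONTINUE) until resource_used reaches 5, is then paused with its lease doubled to 10,
# and on later runs is paused again at 20, 40, ... under the default increase_factor of 2.0.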
|
||||
|
||||
|
||||
class ChaChaScheduler(OnlineSuccessiveDoublingScheduler):
|
||||
""" Keep the top performing learners running
|
||||
|
||||
Methods:
|
||||
on_trial_result(trial_runner, trial, result)
|
||||
Report result and return a decision on the trial's status
|
||||
choose_trial_to_run(trial_runner)
|
||||
Decide which trial to run next
|
||||
"""
|
||||
def __init__(self, increase_factor: float = 2.0, **kwargs):
|
||||
'''
|
||||
Args:
|
||||
increase_factor: a multiplicative factor used to increase resource lease.
|
||||
The default value is 2.0
|
||||
'''
|
||||
super().__init__(increase_factor)
|
||||
self._keep_champion = kwargs.get('keep_champion', True)
|
||||
self._keep_challenger_metric = kwargs.get('keep_challenger_metric', 'ucb')
|
||||
self._keep_challenger_ratio = kwargs.get('keep_challenger_ratio', 0.5)
|
||||
self._pause_old_froniter = kwargs.get('pause_old_froniter', False)
|
||||
logger.info('Using chacha scheduler with config %s', kwargs)
|
||||
|
||||
def on_trial_result(self, trial_runner, trial: Trial, result: Dict):
|
||||
"""Report result and return a decision on the trial's status
|
||||
|
||||
Make a decision according to: SuccessiveDoubling + champion check + performance check
|
||||
"""
|
||||
# Doubling scheduler makes a decision
|
||||
decision = super().on_trial_result(trial_runner, trial, result)
|
||||
# ***********Check whether the trial has been paused since a new champion is promoted**
|
||||
# NOTE: This check is not enabled by default. Just keeping it for experimentation purpose.
|
||||
## trial.is_checked_under_current_champion being False means the trial
|
||||
# has not been paused since the new champion is promoted. If so, we need to
|
||||
# tentatively pause it such that new trials can possibly be taken into consideration
|
||||
# NOTE: This may need to be changed. We need to do this because we only add trials
|
||||
# into the OnlineTrialRunner when there are available slots. Maybe we need to consider
|
||||
# adding max_running_trial number of trials once a new champion is promoted.
|
||||
if self._pause_old_froniter and not trial.is_checked_under_current_champion:
|
||||
if decision == TrialScheduler.CONTINUE:
|
||||
decision = TrialScheduler.PAUSE
|
||||
trial.set_checked_under_current_champion(True)
|
||||
logger.info('Tentatively set trial as paused')
|
||||
|
||||
# ****************Keep the champion always running******************
|
||||
if self._keep_champion and trial.trial_id == trial_runner.champion_trial.trial_id and \
|
||||
decision == TrialScheduler.PAUSE:
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
# ****************Keep the trials with top performance always running******************
|
||||
if self._keep_challenger_ratio is not None:
|
||||
if decision == TrialScheduler.PAUSE:
|
||||
logger.debug('champion, %s', trial_runner.champion_trial.trial_id)
|
||||
# this can be inefficient when the number of trials is large. TODO: improve efficiency.
|
||||
top_trials = trial_runner.get_top_running_trials(self._keep_challenger_ratio,
|
||||
self._keep_challenger_metric)
|
||||
logger.debug('top_learners: %s', top_trials)
|
||||
if trial in top_trials:
|
||||
logger.debug('top runner %s: set from PAUSE to CONTINUE', trial.trial_id)
|
||||
return TrialScheduler.CONTINUE
|
||||
return decision
|
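A minimal sketch of the successive-doubling rule above, using stand-in trial and result objects; the import paths below are assumptions and may need to be adjusted to wherever this module lives in the package.

from flaml.scheduler.online_scheduler import OnlineSuccessiveDoublingScheduler  # path assumed
from flaml.scheduler.trial_scheduler import TrialScheduler  # path assumed


class FakeResult:
    def __init__(self, resource_used):
        self.resource_used = resource_used


class FakeTrial:
    trial_id = 'trial_0'

    def __init__(self, resource_lease):
        self.resource_lease = resource_lease
        self.result = None

    def set_resource_lease(self, lease):
        self.resource_lease = lease


scheduler = OnlineSuccessiveDoublingScheduler(increase_factor=2.0)
trial = FakeTrial(resource_lease=100)
trial.result = FakeResult(resource_used=60)
# resource used (60) < lease (100), so the trial keeps running
assert scheduler.on_trial_result(None, trial, {}) == TrialScheduler.CONTINUE
trial.result = FakeResult(resource_used=100)
# lease exhausted: the lease is doubled and the trial is paused
assert scheduler.on_trial_result(None, trial, {}) == TrialScheduler.PAUSE
assert trial.resource_lease == 200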
|
@ -0,0 +1,157 @@
|
|||
'''
|
||||
Copyright 2020 The Ray Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
This source file is adapted here because ray does not fully support Windows.
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
'''
|
||||
from typing import Dict, Optional
|
||||
|
||||
from flaml.tune import trial_runner
|
||||
from flaml.tune.result import DEFAULT_METRIC
|
||||
from flaml.tune.trial import Trial
|
||||
|
||||
|
||||
class TrialScheduler:
|
||||
"""Interface for implementing a Trial Scheduler class."""
|
||||
|
||||
CONTINUE = "CONTINUE" #: Status for continuing trial execution
|
||||
PAUSE = "PAUSE" #: Status for pausing trial execution
|
||||
STOP = "STOP" #: Status for stopping trial execution
|
||||
|
||||
_metric = None
|
||||
|
||||
@property
|
||||
def metric(self):
|
||||
return self._metric
|
||||
|
||||
def set_search_properties(self, metric: Optional[str],
|
||||
mode: Optional[str]) -> bool:
|
||||
"""Pass search properties to scheduler.
|
||||
This method acts as an alternative to instantiating schedulers
|
||||
that react to metrics with their own `metric` and `mode` parameters.
|
||||
Args:
|
||||
metric (str): Metric to optimize
|
||||
mode (str): One of ["min", "max"]. Direction to optimize.
|
||||
"""
|
||||
if self._metric and metric:
|
||||
return False
|
||||
if metric:
|
||||
self._metric = metric
|
||||
|
||||
if self._metric is None:
|
||||
# Per default, use anonymous metric
|
||||
self._metric = DEFAULT_METRIC
|
||||
|
||||
return True
|
||||
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
"""Called when a new trial is added to the trial runner."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_error(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
"""Notification for the error of trial.
|
||||
This will only be called when the trial is in the RUNNING state."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict) -> str:
|
||||
"""Called on each intermediate result returned by a trial.
|
||||
At this point, the trial scheduler can make a decision by returning
|
||||
one of CONTINUE, PAUSE, and STOP. This will only be called when the
|
||||
trial is in the RUNNING state."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict):
|
||||
"""Notification for the completion of trial.
|
||||
This will only be called when the trial is in the RUNNING state and
|
||||
either completes naturally or by manual termination."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
"""Called to remove trial.
|
||||
This is called when the trial is in PAUSED or PENDING state. Otherwise,
|
||||
call `on_trial_complete`."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def choose_trial_to_run(
|
||||
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
|
||||
"""Called to choose a new trial to run.
|
||||
This should return one of the trials in trial_runner that is in
|
||||
the PENDING or PAUSED state. This function must be idempotent.
|
||||
If no trial is ready, return None."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def debug_string(self) -> str:
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, checkpoint_path: str):
|
||||
"""Save trial scheduler to a checkpoint"""
|
||||
raise NotImplementedError
|
||||
|
||||
def restore(self, checkpoint_path: str):
|
||||
"""Restore trial scheduler from checkpoint."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FIFOScheduler(TrialScheduler):
|
||||
"""Simple scheduler that just runs trials in submission order."""
|
||||
|
||||
def on_trial_add(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
pass
|
||||
|
||||
def on_trial_error(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
pass
|
||||
|
||||
def on_trial_result(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict) -> str:
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def on_trial_complete(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial, result: Dict):
|
||||
pass
|
||||
|
||||
def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner",
|
||||
trial: Trial):
|
||||
pass
|
||||
|
||||
def choose_trial_to_run(
|
||||
self, trial_runner: "trial_runner.TrialRunner") -> Optional[Trial]:
|
||||
for trial in trial_runner.get_trials():
|
||||
if (trial.status == Trial.PENDING
|
||||
and trial_runner.has_resources_for_trial(trial)):
|
||||
return trial
|
||||
for trial in trial_runner.get_trials():
|
||||
if (trial.status == Trial.PAUSED
|
||||
and trial_runner.has_resources_for_trial(trial)):
|
||||
return trial
|
||||
return None
|
||||
|
||||
def debug_string(self) -> str:
|
||||
return "Using FIFO scheduling algorithm."
|
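A minimal sketch of extending the interface above: a scheduler that behaves like FIFOScheduler but stops any trial once it reports a chosen iteration budget (the class name and budget are illustrative, not part of FLAML).

class MaxIterScheduler(FIFOScheduler):
    """Stop trials after a fixed number of training iterations."""

    def __init__(self, max_iter=10):
        self._max_iter = max_iter

    def on_trial_result(self, trial_runner, trial, result):
        # 'training_iteration' is the TRAINING_ITERATION key defined in result.py
        if result.get("training_iteration", 0) >= self._max_iter:
            return TrialScheduler.STOP
        return TrialScheduler.CONTINUE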
|
@ -1,2 +1,6 @@
|
|||
from .blendsearch import CFO, BlendSearch, BlendSearchTuner
|
||||
from .flow2 import FLOW2
|
||||
try:
|
||||
from .online_searcher import ChampionFrontierSearcher
|
||||
except ImportError:
|
||||
print('need to install vowpalwabbit to use ChampionFrontierSearcher')
|
||||
|
|
|
@ -51,8 +51,8 @@ class BlendSearch(Searcher):
|
|||
|
||||
Args:
|
||||
metric: A string of the metric name to optimize for.
|
||||
|
||||
mode: A string in ['min', 'max'] to specify the objective as
|
||||
minimization or maximization.
|
||||
space: A dictionary to specify the search space.
|
||||
points_to_evaluate: Initial parameter suggestions to be run first.
|
||||
low_cost_partial_config: A dictionary from a subset of
|
||||
|
@ -107,6 +107,13 @@ class BlendSearch(Searcher):
|
|||
'''
|
||||
self._metric, self._mode = metric, mode
|
||||
init_config = low_cost_partial_config or {}
|
||||
if not init_config:
|
||||
logger.warning(
|
||||
"No low-cost init config given to the search algorithm."
|
||||
"For cost-frugal search, "
|
||||
"consider providing init values for cost-related hps via "
|
||||
"'init_config'."
|
||||
)
|
||||
self._points_to_evaluate = points_to_evaluate or []
|
||||
self._config_constraints = config_constraints
|
||||
self._metric_constraints = metric_constraints
|
||||
|
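A hedged usage sketch related to the warning above (the search space and values are illustrative): supplying low_cost_partial_config gives BlendSearch a cheap starting point so the warning is not triggered.

from flaml import tune
from flaml import BlendSearch

search_space = {
    "n_estimators": tune.qrandint(4, 1000, 1),
    "learning_rate": tune.loguniform(1e-3, 1.0),
}
searcher = BlendSearch(
    metric="val_loss",
    mode="min",
    space=search_space,
    low_cost_partial_config={"n_estimators": 4},  # cheap value for the cost-related hp
)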
@ -202,6 +209,10 @@ class BlendSearch(Searcher):
|
|||
self._metric_constraint_satisfied = state._metric_constraint_satisfied
|
||||
self._metric_constraint_penalty = state._metric_constraint_penalty
|
||||
|
||||
@property
|
||||
def metric_target(self):
|
||||
return self._metric_target
|
||||
|
||||
def restore_from_dir(self, checkpoint_dir: str):
|
||||
super().restore_from_dir(checkpoint_dir)
|
||||
|
||||
|
|
|
@ -47,8 +47,8 @@ class FLOW2(Searcher):
|
|||
to the initial low-cost values.
|
||||
e.g. {'epochs': 1}
|
||||
metric: A string of the metric name to optimize for.
|
||||
|
||||
mode: A string in ['min', 'max'] to specify the objective as
|
||||
minimization or maximization.
|
||||
cat_hp_cost: A dictionary from a subset of categorical dimensions
|
||||
to the relative cost of each choice.
|
||||
e.g.,
|
||||
|
@ -92,13 +92,6 @@ class FLOW2(Searcher):
|
|||
self.space = flatten_dict(self.space, prevent_delimiter=True)
|
||||
self._random = np.random.RandomState(seed)
|
||||
self._seed = seed
|
||||
if not init_config:
|
||||
logger.warning(
|
||||
"No init config given to FLOW2. Using random initial config."
|
||||
"For cost-frugal search, "
|
||||
"consider providing init values for cost-related hps via "
|
||||
"'init_config'."
|
||||
)
|
||||
self.init_config = init_config
|
||||
self.best_config = flatten_dict(init_config)
|
||||
self.cat_hp_cost = cat_hp_cost
|
||||
|
@ -508,6 +501,7 @@ class FLOW2(Searcher):
|
|||
1. same incumbent, increase resource
|
||||
2. same resource, move from the incumbent to a random direction
|
||||
3. same resource, move from the incumbent to the opposite direction
|
||||
#TODO: better decouple FLOW2 config suggestion and stepsize update
|
||||
'''
|
||||
self.trial_count_proposed += 1
|
||||
if self._num_complete4incumbent > 0 and self.cost_incumbent and \
|
||||
|
|
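A conceptual sketch of the three-way move pattern described in the docstring above (illustrative numbers only, not FLAML's actual update code): after proposing a move from the incumbent in a random direction, the complementary proposal goes the same distance in the opposite direction.

import numpy as np

incumbent = np.array([0.5, 0.2])   # current best point in the normalized space
step = np.array([0.1, -0.05])      # a random direction scaled by the step size
proposal_random = incumbent + step       # case 2: move in a random direction
proposal_opposite = incumbent - step     # case 3: move in the opposite direction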
|
@ -0,0 +1,360 @@
|
|||
import numpy as np
|
||||
import logging
|
||||
import itertools
|
||||
from typing import Dict, Optional, List
|
||||
from flaml.tune import Categorical, Float, PolynomialExpansionSet
|
||||
from flaml.tune import Trial
|
||||
from flaml.onlineml import VowpalWabbitTrial
|
||||
from flaml.searcher import CFO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseSearcher:
|
||||
"""Implementation of the BaseSearcher
|
||||
|
||||
Methods:
|
||||
set_search_properties(metric, mode, config)
|
||||
next_trial()
|
||||
on_trial_result(trial_id, result)
|
||||
on_trial_complete()
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
):
|
||||
pass
|
||||
|
||||
def set_search_properties(self, metric: Optional[str] = None, mode: Optional[str] = None,
|
||||
config: Optional[Dict] = None):
|
||||
if metric:
|
||||
self._metric = metric
|
||||
if mode:
|
||||
assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
|
||||
self._mode = mode
|
||||
|
||||
def next_trial(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def on_trial_result(self, trial_id: str, result: Dict):
|
||||
pass
|
||||
|
||||
def on_trial_complete(self, trial):
|
||||
pass
|
||||
|
||||
|
||||
class ChampionFrontierSearcher(BaseSearcher):
|
||||
"""The ChampionFrontierSearcher class
|
||||
|
||||
Methods:
|
||||
set_search_properties(metric, mode, config)
|
||||
Generate a list of new challengers, and add them to the _challenger_list
|
||||
next_trial()
|
||||
Pop a trial from the _challenger_list
|
||||
on_trial_result(trial_id, result)
|
||||
Do nothing
|
||||
on_trial_complete()
|
||||
Do nothing
|
||||
|
||||
NOTE:
|
||||
This class serves the role of ConfigOracle.
|
||||
Every time we create an online trial, we generate a searcher_trial_id.
|
||||
At the same time, we also record the trial_id of the VW trial.
|
||||
Note that the trial_id is a unique signature of the configuration.
|
||||
So if two VWTrials are associated with the same config, they will have the same trial_id
|
||||
(although not the same searcher_trial_id).
|
||||
searcher_trial_id will be used in suggest()
|
||||
"""
|
||||
# ****the following constants are used when generating new challengers in
|
||||
# the _query_config_oracle function
|
||||
# how many items to add when doing the expansion
|
||||
# (i.e. how many interaction items to add at each time)
|
||||
POLY_EXPANSION_ADDITION_NUM = 1
|
||||
# the order of polynomial expansions to add based on the given seed interactions
|
||||
EXPANSION_ORDER = 2
|
||||
# the number of new challengers with new numerical hyperparameter configs
|
||||
NUMERICAL_NUM = 2
|
||||
|
||||
# In order to use CFO, a loss name and loss values of configs are needed.
|
||||
# Since CFO in fact only requires the relative loss order of two configs to perform
|
||||
# the update, a pseudo loss can be used as long as the relative performance orders
|
||||
# of different configs are preserved. We set the loss of the init config to be
|
||||
# a large value (CFO_SEARCHER_LARGE_LOSS), and set the loss of the better config as
|
||||
# 0.95 of the previous best config's loss.
|
||||
# NOTE: this setting depends on the assumption that set_search_properties (and thus
|
||||
# _query_config_oracle) is only triggered when a better champion is found.
|
||||
CFO_SEARCHER_METRIC_NAME = 'pseudo_loss'
|
||||
CFO_SEARCHER_LARGE_LOSS = 1e6
|
||||
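# Illustration (not part of the original file): with the pseudo-loss scheme
# described above, the values reported to CFO over successive champion
# promotions look like
#   init config    -> 1e6
#   1st promotion  -> 1e6 * 0.95 = 950000.0
#   2nd promotion  -> 950000.0 * 0.95 = 902500.0
# so CFO only ever observes that each new champion improves on the last.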
|
||||
# the random seed used in generating numerical hyperparameter configs (when CFO is not used)
|
||||
NUM_RANDOM_SEED = 111
|
||||
|
||||
CHAMPION_TRIAL_NAME = 'champion_trial'
|
||||
TRIAL_CLASS = VowpalWabbitTrial
|
||||
|
||||
def __init__(self,
|
||||
init_config: Dict,
|
||||
space: Optional[Dict] = None,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
random_seed: Optional[int] = 2345,
|
||||
online_trial_args: Optional[Dict] = {},
|
||||
nonpoly_searcher_name: Optional[str] = 'CFO'
|
||||
):
|
||||
'''Constructor
|
||||
|
||||
Args:
|
||||
init_config: dict
|
||||
space: dict
|
||||
metric: str
|
||||
mode: str
|
||||
random_seed: int
|
||||
online_trial_args: dict
|
||||
nonpoly_searcher_name: A string to specify the search algorithm
|
||||
for nonpoly hyperparameters
|
||||
'''
|
||||
self._init_config = init_config
|
||||
self._space = space
|
||||
self._seed = random_seed
|
||||
self._online_trial_args = online_trial_args
|
||||
self._nonpoly_searcher_name = nonpoly_searcher_name
|
||||
|
||||
self._random_state = np.random.RandomState(self._seed)
|
||||
self._searcher_for_nonpoly_hp = {}
|
||||
self._space_of_nonpoly_hp = {}
|
||||
# dicts to remember the mapping between searcher_trial_id and trial_id
|
||||
self._searcher_trialid_to_trialid = {} # key: searcher_trial_id, value: trial_id
|
||||
self._trialid_to_searcher_trial_id = {} # value: trial_id, key: searcher_trial_id
|
||||
self._challenger_list = []
|
||||
# initialize the search in set_search_properties
|
||||
self.set_search_properties(config={self.CHAMPION_TRIAL_NAME: None}, init_call=True)
|
||||
logger.debug('using random seed %s in config oracle', self._seed)
|
||||
|
||||
def set_search_properties(self, metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
config: Optional[Dict] = {},
|
||||
init_call: Optional[bool] = False):
|
||||
"""Construct search space with given config, and setup the search
|
||||
"""
|
||||
super().set_search_properties(metric, mode, config)
|
||||
# *********Use ConfigOracle (i.e., self._query_config_oracle) to generate a list of new challengers
|
||||
logger.info('champion trial %s', config)
|
||||
champion_trial = config.get(self.CHAMPION_TRIAL_NAME, None)
|
||||
if champion_trial is None:
|
||||
champion_trial = self._create_trial_from_config(self._init_config)
|
||||
# generate a new list of challenger trials
|
||||
new_challenger_list = self._query_config_oracle(champion_trial.config,
|
||||
champion_trial.trial_id,
|
||||
self._trialid_to_searcher_trial_id[champion_trial.trial_id])
|
||||
# add the newly generated challengers to existing challengers
|
||||
# there can be duplicates and we check duplicates when calling next_trial()
|
||||
self._challenger_list = self._challenger_list + new_challenger_list
|
||||
# add the champion as part of the new_challenger_list when called initially
|
||||
if init_call:
|
||||
self._challenger_list.append(champion_trial)
|
||||
logger.critical('Created challengers from champion %s', champion_trial.trial_id)
|
||||
logger.critical('New challenger size %s, %s', len(self._challenger_list),
|
||||
[t.trial_id for t in self._challenger_list])
|
||||
|
||||
def next_trial(self):
|
||||
"""Return a trial from the _challenger_list
|
||||
"""
|
||||
next_trial = None
|
||||
if self._challenger_list:
|
||||
next_trial = self._challenger_list.pop()
|
||||
return next_trial
|
||||
|
||||
def _create_trial_from_config(self, config, searcher_trial_id=None):
|
||||
if searcher_trial_id is None:
|
||||
searcher_trial_id = Trial.generate_id()
|
||||
trial = self.TRIAL_CLASS(config, **self._online_trial_args)
|
||||
self._searcher_trialid_to_trialid[searcher_trial_id] = trial.trial_id
|
||||
# only update the dict when the trial_id does not exist
|
||||
if trial.trial_id not in self._trialid_to_searcher_trial_id:
|
||||
self._trialid_to_searcher_trial_id[trial.trial_id] = searcher_trial_id
|
||||
return trial
|
||||
|
||||
def _query_config_oracle(self, seed_config, seed_config_trial_id,
|
||||
seed_config_searcher_trial_id=None) -> List[Trial]:
|
||||
"""Give the seed config, generate a list of new configs (which are supposed to include
|
||||
at least one config that has better performance than the input seed_config)
|
||||
"""
|
||||
# group the hyperparameters according to whether their configs are independent
|
||||
# of the other hyperparameters
|
||||
hyperparameter_config_groups = []
|
||||
searcher_trial_ids_groups = []
|
||||
nonpoly_config = {}
|
||||
for k, v in seed_config.items():
|
||||
config_domain = self._space[k]
|
||||
if isinstance(config_domain, PolynomialExpansionSet):
|
||||
# get candidate configs for hyperparameters of the PolynomialExpansionSet type
|
||||
partial_new_configs = self._generate_independent_hp_configs(k, v, config_domain)
|
||||
if partial_new_configs:
|
||||
hyperparameter_config_groups.append(partial_new_configs)
|
||||
# does not have searcher_trial_ids
|
||||
searcher_trial_ids_groups.append([])
|
||||
elif isinstance(config_domain, Float) or isinstance(config_domain, Categorical):
|
||||
# otherwise we need to deal with them in group
|
||||
nonpoly_config[k] = v
|
||||
if k not in self._space_of_nonpoly_hp:
|
||||
self._space_of_nonpoly_hp[k] = self._space[k]
|
||||
|
||||
# -----------generate partial new configs for non-PolynomialExpansionSet hyperparameters
|
||||
if nonpoly_config:
|
||||
new_searcher_trial_ids = []
|
||||
partial_new_nonpoly_configs = []
|
||||
if 'CFO' in self._nonpoly_searcher_name:
|
||||
if seed_config_trial_id not in self._searcher_for_nonpoly_hp:
|
||||
self._searcher_for_nonpoly_hp[seed_config_trial_id] = CFO(space=self._space_of_nonpoly_hp,
|
||||
points_to_evaluate=[nonpoly_config],
|
||||
metric=self.CFO_SEARCHER_METRIC_NAME,
|
||||
)
|
||||
# initialize the search in set_search_properties
|
||||
self._searcher_for_nonpoly_hp[seed_config_trial_id].set_search_properties(
|
||||
config={'metric_target': self.CFO_SEARCHER_LARGE_LOSS})
|
||||
# We need to call this once, so that the seed config in points_to_evaluate
|
||||
# will be tried
|
||||
self._searcher_for_nonpoly_hp[seed_config_trial_id].suggest(seed_config_searcher_trial_id)
|
||||
# assuming minimization
|
||||
if self._searcher_for_nonpoly_hp[seed_config_trial_id].metric_target is None:
|
||||
pseudo_loss = self.CFO_SEARCHER_LARGE_LOSS
|
||||
else:
|
||||
pseudo_loss = self._searcher_for_nonpoly_hp[seed_config_trial_id].metric_target * 0.95
|
||||
pseudo_result_to_report = {}
|
||||
for k, v in nonpoly_config.items():
|
||||
pseudo_result_to_report['config/' + str(k)] = v
|
||||
pseudo_result_to_report[self.CFO_SEARCHER_METRIC_NAME] = pseudo_loss
|
||||
pseudo_result_to_report['time_total_s'] = 1
|
||||
self._searcher_for_nonpoly_hp[seed_config_trial_id].on_trial_complete(seed_config_searcher_trial_id,
|
||||
result=pseudo_result_to_report)
|
||||
while len(partial_new_nonpoly_configs) < self.NUMERICAL_NUM:
|
||||
# suggest multiple times
|
||||
new_searcher_trial_id = Trial.generate_id()
|
||||
new_searcher_trial_ids.append(new_searcher_trial_id)
|
||||
suggestion = self._searcher_for_nonpoly_hp[seed_config_trial_id].suggest(new_searcher_trial_id)
|
||||
if suggestion is not None:
|
||||
partial_new_nonpoly_configs.append(suggestion)
|
||||
logger.info('partial_new_nonpoly_configs %s', partial_new_nonpoly_configs)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if partial_new_nonpoly_configs:
|
||||
hyperparameter_config_groups.append(partial_new_nonpoly_configs)
|
||||
searcher_trial_ids_groups.append(new_searcher_trial_ids)
|
||||
# ----------- coordinate generation of new challengers in the case of multiple groups
|
||||
new_trials = []
|
||||
for i in range(len(hyperparameter_config_groups)):
|
||||
logger.info('hyperparameter_config_groups[i] %s %s',
|
||||
len(hyperparameter_config_groups[i]),
|
||||
hyperparameter_config_groups[i])
|
||||
for j, new_partial_config in enumerate(hyperparameter_config_groups[i]):
|
||||
new_seed_config = seed_config.copy()
|
||||
new_seed_config.update(new_partial_config)
|
||||
# For some groups of the hyperparameters, we may have already generated the
|
||||
# searcher_trial_id. In that case, we only need to retrieve the searcher_trial_id
|
||||
# instead of generating it again. So we do not generate searcher_trial_id and
|
||||
# instead set the searcher_trial_id to be None. When creating a trial from a config,
|
||||
# a searcher_trial_id will be generated if None is provided.
|
||||
# TODO: An alternative option is to generate a searcher_trial_id for each partial config
|
||||
if searcher_trial_ids_groups[i]:
|
||||
new_searcher_trial_id = searcher_trial_ids_groups[i][j]
|
||||
else:
|
||||
new_searcher_trial_id = None
|
||||
new_trial = self._create_trial_from_config(new_seed_config, new_searcher_trial_id)
|
||||
new_trials.append(new_trial)
|
||||
logger.info('new_configs %s', [t.trial_id for t in new_trials])
|
||||
return new_trials
|
||||
|
||||
def _generate_independent_hp_configs(self, hp_name, current_config_value, config_domain) -> List:
|
||||
if isinstance(config_domain, PolynomialExpansionSet):
|
||||
seed_interactions = list(current_config_value) + list(config_domain.init_monomials)
|
||||
logger.critical('Seed namespaces (singletons and interactions): %s', seed_interactions)
|
||||
logger.info('current_config_value %s %s', current_config_value, seed_interactions)
|
||||
configs = self._generate_poly_expansion_sets(seed_interactions,
|
||||
self.EXPANSION_ORDER,
|
||||
config_domain.allow_self_inter,
|
||||
config_domain.highest_poly_order,
|
||||
self.POLY_EXPANSION_ADDITION_NUM,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
configs_w_key = [{hp_name: hp_config} for hp_config in configs]
|
||||
return configs_w_key
|
||||
|
||||
def _generate_poly_expansion_sets(self, seed_interactions, order, allow_self_inter,
|
||||
highest_poly_order, interaction_num_to_add):
|
||||
champion_all_combinations = self._generate_all_comb(seed_interactions, order, allow_self_inter, highest_poly_order)
|
||||
space = sorted(list(itertools.combinations(
|
||||
champion_all_combinations, interaction_num_to_add)))
|
||||
self._random_state.shuffle(space)
|
||||
candidate_configs = [set(seed_interactions) | set(item) for item in space]
|
||||
final_candidate_configs = []
|
||||
for c in candidate_configs:
|
||||
new_c = set([e for e in c if len(e) > 1])
|
||||
final_candidate_configs.append(new_c)
|
||||
return final_candidate_configs
|
||||
|
||||
@staticmethod
|
||||
def _generate_all_comb(seed_interactions: list, seed_interaction_order: int,
|
||||
allow_self_inter: Optional[bool] = False,
|
||||
highest_poly_order: Optional[int] = None):
|
||||
"""Generate new interactions by doing up to seed_interaction_order on the seed_interactions
|
||||
|
||||
Args:
|
||||
seed_interactions (List[str]): the seed config, which is a list of interaction strings
|
||||
(including the singletons)
|
||||
seed_interaction_order (int): the maximum order of interactions to perform on the seed_config
|
||||
allow_self_inter (bool): whether self-interaction is allowed
|
||||
e.g. if set False, 'aab' will be considered as 'ab', i.e. duplicates in the interaction
|
||||
string are removed.
|
||||
highest_poly_order (int): the highest polynomial order allowed for the resulting interaction.
|
||||
e.g. if set 3, the interaction 'abcd' will be excluded.
|
||||
"""
|
||||
|
||||
def get_interactions(list1, list2):
|
||||
"""Get combinatorial list of tuples
|
||||
"""
|
||||
new_list = []
|
||||
for i in list1:
|
||||
for j in list2:
|
||||
# each interaction is sorted. E.g. after sorting
|
||||
# 'abc' 'cba' 'bca' are all 'abc'
|
||||
# this is done to ensure we can use the config as the signature
|
||||
# of the trial, i.e., trial id.
|
||||
new_interaction = ''.join(sorted(i + j))
|
||||
if new_interaction not in new_list:
|
||||
new_list.append(new_interaction)
|
||||
return new_list
|
||||
|
||||
def strip_self_inter(s):
|
||||
"""Remove duplicates in an interaction string
|
||||
"""
|
||||
if len(s) == len(set(s)):
|
||||
return s
|
||||
else:
|
||||
# return ''.join(sorted(set(s)))
|
||||
new_s = ''
|
||||
char_list = []
|
||||
for i in s:
|
||||
if i not in char_list:
|
||||
char_list.append(i)
|
||||
new_s += i
|
||||
return new_s
|
||||
|
||||
interactions = seed_interactions.copy()
|
||||
all_interactions = []
|
||||
while seed_interaction_order > 1:
|
||||
interactions = get_interactions(interactions, seed_interactions)
|
||||
seed_interaction_order -= 1
|
||||
all_interactions += interactions
|
||||
if not allow_self_inter:
|
||||
all_interactions_no_self_inter = []
|
||||
for s in all_interactions:
|
||||
s_no_inter = strip_self_inter(s)
|
||||
if len(s_no_inter) > 1 and s_no_inter not in all_interactions_no_self_inter:
|
||||
all_interactions_no_self_inter.append(s_no_inter)
|
||||
all_interactions = all_interactions_no_self_inter
|
||||
if highest_poly_order is not None:
|
||||
all_interactions = [c for c in all_interactions if len(c) <= highest_poly_order]
|
||||
logger.info('all_combinations %s', all_interactions)
|
||||
return all_interactions
|
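A worked illustration of the order-2 expansion described in the docstring above (a standalone sketch that mirrors the behavior, not a call into FLAML): starting from singleton namespaces 'a', 'b', 'c' with allow_self_inter=False, the expansion yields exactly the pairwise interactions.

import itertools

seed = ['a', 'b', 'c']
order_2 = {''.join(sorted(pair)) for pair in itertools.combinations(seed, 2)}
print(sorted(order_2))  # ['ab', 'ac', 'bc']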
|
@ -5,3 +5,6 @@ except ImportError:
|
|||
from .sample import (uniform, quniform, choice, randint, qrandint, randn,
|
||||
qrandn, loguniform, qloguniform)
|
||||
from .tune import run, report
|
||||
from .sample import polynomial_expansion_set
|
||||
from .sample import PolynomialExpansionSet, Categorical, Float
|
||||
from .trial import Trial
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
{
|
||||
"Registrations": [
|
||||
{
|
||||
"Component": {
|
||||
"Type": "pip",
|
||||
"pip": {"Name": "ray[tune]", "Version": "1.2.0" }
|
||||
},
|
||||
"DevelopmentDependency": false
|
||||
}
|
||||
]
|
||||
}
|
|
@ -0,0 +1,148 @@
|
|||
'''
|
||||
Copyright 2020 The Ray Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
This source file is adapted here because ray does not fully support Windows.
|
||||
|
||||
Copyright (c) Microsoft Corporation.
|
||||
'''
|
||||
import os
|
||||
|
||||
# yapf: disable
|
||||
# __sphinx_doc_begin__
|
||||
# (Optional/Auto-filled) training is terminated. Filled only if not provided.
|
||||
DONE = "done"
|
||||
|
||||
# (Optional) Enum for user controlled checkpoint
|
||||
SHOULD_CHECKPOINT = "should_checkpoint"
|
||||
|
||||
# (Auto-filled) The hostname of the machine hosting the training process.
|
||||
HOSTNAME = "hostname"
|
||||
|
||||
# (Auto-filled) The auto-assigned id of the trial.
|
||||
TRIAL_ID = "trial_id"
|
||||
|
||||
# (Auto-filled) The auto-assigned id of the trial.
|
||||
EXPERIMENT_TAG = "experiment_tag"
|
||||
|
||||
# (Auto-filled) The node ip of the machine hosting the training process.
|
||||
NODE_IP = "node_ip"
|
||||
|
||||
# (Auto-filled) The pid of the training process.
|
||||
PID = "pid"
|
||||
|
||||
# (Optional) Default (anonymous) metric when using tune.report(x)
|
||||
DEFAULT_METRIC = "_metric"
|
||||
|
||||
# (Optional) Mean reward for current training iteration
|
||||
EPISODE_REWARD_MEAN = "episode_reward_mean"
|
||||
|
||||
# (Optional) Mean loss for training iteration
|
||||
MEAN_LOSS = "mean_loss"
|
||||
|
||||
# (Optional) Mean loss for training iteration
|
||||
NEG_MEAN_LOSS = "neg_mean_loss"
|
||||
|
||||
# (Optional) Mean accuracy for training iteration
|
||||
MEAN_ACCURACY = "mean_accuracy"
|
||||
|
||||
# Number of episodes in this iteration.
|
||||
EPISODES_THIS_ITER = "episodes_this_iter"
|
||||
|
||||
# (Optional/Auto-filled) Accumulated number of episodes for this trial.
|
||||
EPISODES_TOTAL = "episodes_total"
|
||||
|
||||
# Number of timesteps in this iteration.
|
||||
TIMESTEPS_THIS_ITER = "timesteps_this_iter"
|
||||
|
||||
# (Auto-filled) Accumulated number of timesteps for this entire trial.
|
||||
TIMESTEPS_TOTAL = "timesteps_total"
|
||||
|
||||
# (Auto-filled) Time in seconds this iteration took to run.
|
||||
# This may be overridden to override the system-computed time difference.
|
||||
TIME_THIS_ITER_S = "time_this_iter_s"
|
||||
|
||||
# (Auto-filled) Accumulated time in seconds for this entire trial.
|
||||
TIME_TOTAL_S = "time_total_s"
|
||||
|
||||
# (Auto-filled) The index of this training iteration.
|
||||
TRAINING_ITERATION = "training_iteration"
|
||||
# __sphinx_doc_end__
|
||||
# yapf: enable
|
||||
|
||||
DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", EXPERIMENT_TAG, TRIAL_ID)
|
||||
|
||||
DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL,
|
||||
MEAN_ACCURACY, MEAN_LOSS)
|
||||
|
||||
# Make sure this doesn't regress
|
||||
AUTO_RESULT_KEYS = (
|
||||
TRAINING_ITERATION,
|
||||
TIME_TOTAL_S,
|
||||
EPISODES_TOTAL,
|
||||
TIMESTEPS_TOTAL,
|
||||
NODE_IP,
|
||||
HOSTNAME,
|
||||
PID,
|
||||
TIME_TOTAL_S,
|
||||
TIME_THIS_ITER_S,
|
||||
"timestamp",
|
||||
"experiment_id",
|
||||
"date",
|
||||
"time_since_restore",
|
||||
"iterations_since_restore",
|
||||
"timesteps_since_restore",
|
||||
"config",
|
||||
)
|
||||
|
||||
# __duplicate__ is a magic keyword used internally to
|
||||
# avoid double-logging results when using the Function API.
|
||||
RESULT_DUPLICATE = "__duplicate__"
|
||||
|
||||
# __trial_info__ is a magic keyword used internally to pass trial_info
|
||||
# to the Trainable via the constructor.
|
||||
TRIAL_INFO = "__trial_info__"
|
||||
|
||||
# __stdout_file__/__stderr_file__ are magic keywords used internally
|
||||
# to pass log file locations to the Trainable via the constructor.
|
||||
STDOUT_FILE = "__stdout_file__"
|
||||
STDERR_FILE = "__stderr_file__"
|
||||
|
||||
# Where Tune writes result files by default
|
||||
DEFAULT_RESULTS_DIR = (os.environ.get("TEST_TMPDIR")
|
||||
or os.environ.get("TUNE_RESULT_DIR")
|
||||
or os.path.expanduser("~/ray_results"))
|
||||
|
||||
# Meta file about status under each experiment directory, can be
|
||||
# parsed by automlboard if exists.
|
||||
JOB_META_FILE = "job_status.json"
|
||||
|
||||
# Meta file about status under each trial directory, can be parsed
|
||||
# by automlboard if exists.
|
||||
EXPR_META_FILE = "trial_status.json"
|
||||
|
||||
# File that stores parameters of the trial.
|
||||
EXPR_PARAM_FILE = "params.json"
|
||||
|
||||
# Pickle File that stores parameters of the trial.
|
||||
EXPR_PARAM_PICKLE_FILE = "params.pkl"
|
||||
|
||||
# File that stores the progress of the trial.
|
||||
EXPR_PROGRESS_FILE = "progress.csv"
|
||||
|
||||
# File that stores results of the trial.
|
||||
EXPR_RESULT_FILE = "result.json"
|
||||
|
||||
# Config prefix when using Analysis.
|
||||
CONFIG_PREFIX = "config/"
|
|
@ -414,6 +414,31 @@ class Quantized(Sampler):
|
|||
return list(quantized)
|
||||
|
||||
|
||||
class PolynomialExpansionSet:
|
||||
|
||||
def __init__(self, init_monomials: set = (), highest_poly_order: int = None,
|
||||
allow_self_inter: bool = False):
|
||||
self._init_monomials = init_monomials
|
||||
self._highest_poly_order = highest_poly_order if \
|
||||
highest_poly_order is not None else len(self._init_monomials)
|
||||
self._allow_self_inter = allow_self_inter
|
||||
|
||||
@property
|
||||
def init_monomials(self):
|
||||
return self._init_monomials
|
||||
|
||||
@property
|
||||
def highest_poly_order(self):
|
||||
return self._highest_poly_order
|
||||
|
||||
@property
|
||||
def allow_self_inter(self):
|
||||
return self._allow_self_inter
|
||||
|
||||
def __str__(self):
|
||||
return "PolynomialExpansionSet"
|
||||
|
||||
|
||||
# TODO (krfricke): Remove tune.function
|
||||
def function(func):
|
||||
logger.warning(
|
||||
|
@ -535,3 +560,9 @@ def qrandn(mean: float, sd: float, q: float):
|
|||
integer increment of this value.
|
||||
"""
|
||||
return Float(None, None).normal(mean, sd).quantized(q)
|
||||
|
||||
|
||||
def polynomial_expansion_set(init_monomials: set, highest_poly_order: int = None,
|
||||
allow_self_inter: bool = False):
|
||||
|
||||
return PolynomialExpansionSet(init_monomials, highest_poly_order, allow_self_inter)
|
||||
|
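A usage sketch grounded in the test code later in this diff: declaring the 'interactions' dimension of an AutoVW search space as a PolynomialExpansionSet (the namespace names are illustrative).

from flaml.tune import polynomial_expansion_set

namespaces = {'a', 'b', 'c'}  # raw VW namespaces (illustrative)
interactions_space = polynomial_expansion_set(
    init_monomials=namespaces,
    highest_poly_order=len(namespaces),
    allow_self_inter=False,
)
print(interactions_space)  # prints "PolynomialExpansionSet"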
|
|
@ -1 +1 @@
|
|||
__version__ = "0.4.2"
|
||||
__version__ = "0.5.0"
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -40,6 +40,7 @@ setuptools.setup(
|
|||
"jupyter",
|
||||
"matplotlib==3.2.0",
|
||||
"rgf-python",
|
||||
"vowpalwabbit",
|
||||
],
|
||||
"test": [
|
||||
"flake8>=3.8.4",
|
||||
|
@ -48,6 +49,8 @@ setuptools.setup(
|
|||
"xgboost<1.3",
|
||||
"rgf-python",
|
||||
"optuna==2.3.0",
|
||||
"vowpalwabbit",
|
||||
"openml",
|
||||
],
|
||||
"blendsearch": [
|
||||
"optuna==2.3.0"
|
||||
|
@ -62,6 +65,9 @@ setuptools.setup(
|
|||
"nni": [
|
||||
"nni",
|
||||
],
|
||||
"vw": [
|
||||
"vowpalwabbit",
|
||||
]
|
||||
},
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
|
|
|
@ -0,0 +1,372 @@
|
|||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_squared_error, mean_absolute_error
|
||||
import time
|
||||
import logging
|
||||
from flaml.tune import loguniform, polynomial_expansion_set
|
||||
from vowpalwabbit import pyvw
|
||||
from flaml import AutoVW
|
||||
import string
|
||||
import os
|
||||
import openml
|
||||
VW_DS_DIR = 'test/data/'
|
||||
NS_LIST = list(string.ascii_lowercase) + list(string.ascii_uppercase)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def oml_to_vw_w_grouping(X, y, ds_dir, fname, orginal_dim, group_num,
|
||||
grouping_method='sequential'):
|
||||
# split all_indexes into # group_num of groups
|
||||
max_size_per_group = int(np.ceil(orginal_dim / float(group_num)))
|
||||
# sequential grouping
|
||||
if grouping_method == 'sequential':
|
||||
group_indexes = [] # lists of lists
|
||||
for i in range(group_num):
|
||||
indexes = [ind for ind in range(i * max_size_per_group,
|
||||
min((i + 1) * max_size_per_group, orginal_dim))]
|
||||
if len(indexes) > 0:
|
||||
group_indexes.append(indexes)
|
||||
print(group_indexes)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
if group_indexes:
|
||||
if not os.path.exists(ds_dir):
|
||||
os.makedirs(ds_dir)
|
||||
with open(os.path.join(ds_dir, fname), 'w') as f:
|
||||
if isinstance(X, pd.DataFrame):
|
||||
raise NotImplementedError
|
||||
elif isinstance(X, np.ndarray):
|
||||
for i in range(len(X)):
|
||||
NS_content = []
|
||||
for zz in range(len(group_indexes)):
|
||||
ns_features = ' '.join('{}:{:.6f}'.format(ind, X[i][ind]
|
||||
) for ind in group_indexes[zz])
|
||||
NS_content.append(ns_features)
|
||||
ns_line = '{} |{}'.format(str(y[i]), '|'.join(
|
||||
'{} {}'.format(NS_LIST[j], NS_content[j]
|
||||
) for j in range(len(group_indexes))))
|
||||
f.write(ns_line)
|
||||
f.write('\n')
|
||||
elif isinstance(X, scipy.sparse.csr_matrix):
|
||||
print('NotImplementedError for sparse data')
|
||||
raise NotImplementedError
|
||||
|
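# Example of the VW text line produced above (values illustrative): the target
# followed by one block per namespace group, e.g. for y = 3.2 and groups a, b:
#   3.2 |a 0:0.105000 1:0.233000|b 2:0.981000 3:0.004000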
||||
|
||||
def save_vw_dataset_w_ns(X, y, did, ds_dir, max_ns_num, is_regression):
|
||||
""" convert openml dataset to vw example and save to file
|
||||
"""
|
||||
print('is_regression', is_regression)
|
||||
if is_regression:
|
||||
fname = 'ds_{}_{}_{}.vw'.format(did, max_ns_num, 0)
|
||||
print('dataset size', X.shape[0], X.shape[1])
|
||||
print('saving data', did, ds_dir, fname)
|
||||
dim = X.shape[1]
|
||||
oml_to_vw_w_grouping(X, y, ds_dir, fname, dim, group_num=max_ns_num)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def shuffle_data(X, y, seed):
|
||||
try:
|
||||
n = len(X)
|
||||
except ValueError:
|
||||
n = X.getnnz()
|
||||
|
||||
perm = np.random.RandomState(seed=seed).permutation(n)
|
||||
X_shuf = X[perm, :]
|
||||
y_shuf = y[perm]
|
||||
return X_shuf, y_shuf
|
||||
|
||||
|
||||
def get_oml_to_vw(did, max_ns_num, ds_dir=VW_DS_DIR):
|
||||
success = False
|
||||
print('-----getting oml dataset-------', did)
|
||||
ds = openml.datasets.get_dataset(did)
|
||||
target_attribute = ds.default_target_attribute
|
||||
# if target_attribute is None and did in OML_target_attribute_dict:
|
||||
# target_attribute = OML_target_attribute_dict[did]
|
||||
|
||||
print('target=ds.default_target_attribute', target_attribute)
|
||||
data = ds.get_data(target=target_attribute, dataset_format='array')
|
||||
X, y = data[0], data[1] # return X: pd DataFrame, y: pd series
|
||||
import scipy
|
||||
if scipy.sparse.issparse(X):
|
||||
X = scipy.sparse.csr_matrix.toarray(X)
|
||||
print('is sparse matrix')
|
||||
if data and isinstance(X, np.ndarray):
|
||||
print('-----converting oml to vw and saving oml dataset-------')
|
||||
save_vw_dataset_w_ns(X, y, did, ds_dir, max_ns_num, is_regression=True)
|
||||
success = True
|
||||
else:
|
||||
print('---failed to convert/save oml dataset to vw!!!----')
|
||||
try:
|
||||
X, y = data[0], data[1] # return X: pd DataFrame, y: pd series
|
||||
if data and isinstance(X, np.ndarray):
|
||||
print('-----converting oml to vw and saving oml dataset-------')
|
||||
save_vw_dataset_w_ns(X, y, did, ds_dir, max_ns_num, is_regression=True)
|
||||
success = True
|
||||
else:
|
||||
print('---failed to convert/save oml dataset to vw!!!----')
|
||||
except ValueError:
|
||||
print('-------------failed to get oml dataset!!!', did)
|
||||
return success
|
||||
|
||||
|
||||
def load_vw_dataset(did, ds_dir, is_regression, max_ns_num):
|
||||
import os
|
||||
if is_regression:
|
||||
# the second field specifies the largest number of namespaces used.
|
||||
fname = 'ds_{}_{}_{}.vw'.format(did, max_ns_num, 0)
|
||||
vw_dataset_file = os.path.join(ds_dir, fname)
|
||||
# if file does not exist, generate and save the datasets
|
||||
if not os.path.exists(vw_dataset_file) or os.stat(vw_dataset_file).st_size < 1000:
|
||||
get_oml_to_vw(did, max_ns_num)
|
||||
print(ds_dir, vw_dataset_file)
|
||||
if not os.path.exists(ds_dir):
|
||||
os.makedirs(ds_dir)
|
||||
with open(os.path.join(ds_dir, fname), 'r') as f:
|
||||
vw_content = f.read().splitlines()
|
||||
print(type(vw_content), len(vw_content))
|
||||
return vw_content
|
||||
|
||||
|
||||
def get_data(iter_num=None, dataset_id=None, vw_format=True,
|
||||
max_ns_num=10, shuffle=False, use_log=True, dataset_type='regression'):
|
||||
logging.info('generating data')
|
||||
LOG_TRANSFORMATION_THRESHOLD = 100
|
||||
# get data from simulation
|
||||
import random
|
||||
vw_examples = None
|
||||
data_id = int(dataset_id)
|
||||
# loading oml dataset
|
||||
# data = OpenML2VWData(data_id, max_ns_num, dataset_type)
|
||||
# Y = data.Y
|
||||
if vw_format:
|
||||
# vw_examples = data.vw_examples
|
||||
vw_examples = load_vw_dataset(did=data_id, ds_dir=VW_DS_DIR, is_regression=True,
|
||||
max_ns_num=max_ns_num)
|
||||
Y = []
|
||||
for i, e in enumerate(vw_examples):
|
||||
Y.append(float(e.split('|')[0]))
|
||||
logger.debug('first data %s', vw_examples[0])
|
||||
# do data shuffling or log transformation for oml data when needed
|
||||
if shuffle:
|
||||
random.seed(54321)
|
||||
random.shuffle(vw_examples)
|
||||
|
||||
# do log transformation
|
||||
unique_y = set(Y)
|
||||
min_y = min(unique_y)
|
||||
max_y = max(unique_y)
|
||||
if use_log and max((max_y - min_y), max_y) >= LOG_TRANSFORMATION_THRESHOLD:
|
||||
log_vw_examples = []
|
||||
for v in vw_examples:
|
||||
org_y = v.split('|')[0]
|
||||
y = float(v.split('|')[0])
|
||||
# shift y to ensure all y are positive
|
||||
if min_y <= 0:
|
||||
y = y + abs(min_y) + 1
|
||||
log_y = np.log(y)
|
||||
log_vw = v.replace(org_y + '|', str(log_y) + ' |')
|
||||
log_vw_examples.append(log_vw)
|
||||
logger.info('log_vw_examples %s', log_vw_examples[0:2])
|
||||
if log_vw_examples:
|
||||
return log_vw_examples
|
||||
return vw_examples, Y
|
||||
|
||||
|
||||
class VowpalWabbitNamesspaceTuningProblem:
|
||||
|
||||
def __init__(self, max_iter_num, dataset_id, ns_num, **kwargs):
|
||||
use_log = kwargs.get('use_log', True)
|
||||
shuffle = kwargs.get('shuffle', False)
|
||||
vw_format = kwargs.get('vw_format', True)
|
||||
print('dataset_id', dataset_id)
|
||||
self.vw_examples, self.Y = get_data(max_iter_num, dataset_id=dataset_id,
|
||||
vw_format=vw_format, max_ns_num=ns_num,
|
||||
shuffle=shuffle, use_log=use_log
|
||||
)
|
||||
self.max_iter_num = min(max_iter_num, len(self.Y))
|
||||
self._problem_info = {'max_iter_num': self.max_iter_num,
|
||||
'dataset_id': dataset_id,
|
||||
'ns_num': ns_num,
|
||||
}
|
||||
self._problem_info.update(kwargs)
|
||||
self._fixed_hp_config = kwargs.get('fixed_hp_config', {})
|
||||
self.namespace_feature_dim = AutoVW.get_ns_feature_dim_from_vw_example(self.vw_examples[0])
|
||||
self._raw_namespaces = list(self.namespace_feature_dim.keys())
|
||||
self._setup_search()
|
||||
|
||||
def _setup_search(self):
|
||||
self._search_space = self._fixed_hp_config.copy()
|
||||
self._init_config = self._fixed_hp_config.copy()
|
||||
search_space = {'interactions':
|
||||
polynomial_expansion_set(
|
||||
init_monomials=set(self._raw_namespaces),
|
||||
highest_poly_order=len(self._raw_namespaces),
|
||||
allow_self_inter=False),
|
||||
}
|
||||
init_config = {'interactions': set()}
|
||||
self._search_space.update(search_space)
|
||||
self._init_config.update(init_config)
|
||||
logger.info('search space %s %s %s', self._search_space, self._init_config, self._fixed_hp_config)
|
||||
|
||||
@property
|
||||
def init_config(self):
|
||||
return self._init_config
|
||||
|
||||
@property
|
||||
def search_space(self):
|
||||
return self._search_space
|
||||
|
||||
|
||||
class VowpalWabbitNamesspaceLRTuningProblem(VowpalWabbitNamesspaceTuningProblem):
|
||||
|
||||
def __init__(self, max_iter_num, dataset_id, ns_num, **kwargs):
|
||||
super().__init__(max_iter_num, dataset_id, ns_num, **kwargs)
|
||||
self._setup_search()
|
||||
|
||||
def _setup_search(self):
|
||||
self._search_space = self._fixed_hp_config.copy()
|
||||
self._init_config = self._fixed_hp_config.copy()
|
||||
search_space = {'interactions':
|
||||
polynomial_expansion_set(
|
||||
init_monomials=set(self._raw_namespaces),
|
||||
highest_poly_order=len(self._raw_namespaces),
|
||||
allow_self_inter=False),
|
||||
'learning_rate': loguniform(lower=2e-10, upper=1.0)
|
||||
}
|
||||
init_config = {'interactions': set(), 'learning_rate': 0.5}
|
||||
self._search_space.update(search_space)
|
||||
self._init_config.update(init_config)
|
||||
logger.info('search space %s %s %s', self._search_space, self._init_config, self._fixed_hp_config)
|
||||
|
||||
|
||||
def get_y_from_vw_example(vw_example):
|
||||
""" get y from a vw_example. this works for regression dataset
|
||||
"""
|
||||
return float(vw_example.split('|')[0])
|
||||
|
||||
|
||||
def get_loss(y_pred, y_true, loss_func='squared'):
|
||||
if 'squared' in loss_func:
|
||||
loss = mean_squared_error([y_pred], [y_true])
|
||||
elif 'absolute' in loss_func:
|
||||
loss = mean_absolute_error([y_pred], [y_true])
|
||||
else:
|
||||
loss = None
|
||||
raise NotImplementedError
|
||||
return loss
|
||||
|
||||
|
||||
def online_learning_loop(iter_num, vw_examples, vw_alg, loss_func, method_name=''):
|
||||
"""Implements the online learning loop.
|
||||
Args:
|
||||
iter_num (int): The total number of iterations
|
||||
vw_examples (list): A list of vw examples
|
||||
alg (alg instance): an algorithm instance that has the following methods:
|
||||
- alg.learn(example)
|
||||
- alg.predict(example)
|
||||
loss_func (str): loss function
|
||||
Outputs:
|
||||
cumulative_loss_list (list): the list of cumulative loss from each iteration.
|
||||
It is returned for the convenience of visualization.
|
||||
"""
|
||||
print('rerunning exp....', len(vw_examples), iter_num)
|
||||
loss_list = []
|
||||
y_predict_list = []
|
||||
for i in range(iter_num):
|
||||
vw_x = vw_examples[i]
|
||||
y_true = get_y_from_vw_example(vw_x)
|
||||
# predict step
|
||||
y_pred = vw_alg.predict(vw_x)
|
||||
# learn step
|
||||
vw_alg.learn(vw_x)
|
||||
# calculate one step loss
|
||||
loss = get_loss(y_pred, y_true, loss_func)
|
||||
loss_list.append(loss)
|
||||
y_predict_list.append([y_pred, y_true])
|
||||
|
||||
return loss_list
|
||||
|
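# Note (illustrative): the per-iteration losses returned above are typically
# summarized as a progressive validation average, as the tests below do:
#   average_loss = sum(loss_list) / len(loss_list)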
||||
|
||||
def get_vw_tuning_problem(tuning_hp='NamesapceInteraction'):
|
||||
online_vw_exp_setting = {"max_live_model_num": 5,
|
||||
"fixed_hp_config": {'alg': 'supervised', 'loss_function': 'squared'},
|
||||
"ns_num": 10,
|
||||
"max_iter_num": 10000,
|
||||
}
|
||||
|
||||
# construct openml problem setting based on basic experiment setting
|
||||
vw_oml_problem_args = {"max_iter_num": online_vw_exp_setting['max_iter_num'],
|
||||
"dataset_id": '42183',
|
||||
"ns_num": online_vw_exp_setting['ns_num'],
|
||||
"fixed_hp_config": online_vw_exp_setting['fixed_hp_config'],
|
||||
}
|
||||
if tuning_hp == 'NamesapceInteraction':
|
||||
vw_online_aml_problem = VowpalWabbitNamesspaceTuningProblem(**vw_oml_problem_args)
|
||||
elif tuning_hp == 'NamesapceInteraction+LearningRate':
|
||||
vw_online_aml_problem = VowpalWabbitNamesspaceLRTuningProblem(**vw_oml_problem_args)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return vw_oml_problem_args, vw_online_aml_problem
|
||||
|
||||
|
||||
class TestAutoVW(unittest.TestCase):
|
||||
|
||||
def test_vw_oml_problem_and_vanilla_vw(self):
|
||||
vw_oml_problem_args, vw_online_aml_problem = get_vw_tuning_problem()
|
||||
vanilla_vw = pyvw.vw(**vw_oml_problem_args["fixed_hp_config"])
|
||||
cumulative_loss_list = online_learning_loop(vw_online_aml_problem.max_iter_num,
|
||||
vw_online_aml_problem.vw_examples,
|
||||
vanilla_vw,
|
||||
loss_func=vw_oml_problem_args["fixed_hp_config"].get("loss_function", "squared"),
|
||||
)
|
||||
print('final average loss:', sum(cumulative_loss_list) / len(cumulative_loss_list))
|
||||
|
||||
def test_supervised_vw_tune_namespace(self):
|
||||
# basic experiment setting
|
||||
vw_oml_problem_args, vw_online_aml_problem = get_vw_tuning_problem()
|
||||
autovw = AutoVW(max_live_model_num=5,
|
||||
search_space=vw_online_aml_problem.search_space,
|
||||
init_config=vw_online_aml_problem.init_config,
|
||||
min_resource_lease='auto',
|
||||
random_seed=2345)
|
||||
|
||||
cumulative_loss_list = online_learning_loop(vw_online_aml_problem.max_iter_num,
|
||||
vw_online_aml_problem.vw_examples,
|
||||
autovw,
|
||||
loss_func=vw_oml_problem_args["fixed_hp_config"].get("loss_function", "squared"),
|
||||
)
|
||||
print('final average loss:', sum(cumulative_loss_list) / len(cumulative_loss_list))
|
||||
|
||||
def test_supervised_vw_tune_namespace_learningrate(self):
|
||||
# basic experiment setting
|
||||
vw_oml_problem_args, vw_online_aml_problem = get_vw_tuning_problem(tuning_hp='NamesapceInteraction+LearningRate')
|
||||
autovw = AutoVW(max_live_model_num=5,
|
||||
search_space=vw_online_aml_problem.search_space,
|
||||
init_config=vw_online_aml_problem.init_config,
|
||||
min_resource_lease='auto',
|
||||
random_seed=2345)
|
||||
|
||||
cumulative_loss_list = online_learning_loop(vw_online_aml_problem.max_iter_num,
|
||||
vw_online_aml_problem.vw_examples,
|
||||
autovw,
|
||||
loss_func=vw_oml_problem_args["fixed_hp_config"].get("loss_function", "squared"),
|
||||
)
|
||||
print('final average loss:', sum(cumulative_loss_list) / len(cumulative_loss_list))
|
||||
|
||||
def test_bandit_vw_tune_namespace(self):
|
||||
pass
|
||||
|
||||
def test_bandit_vw_tune_namespace_learningrate(self):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|