Support parallel and add random search (#167)

* non hashable value out of signature

* parallel trials

* add random in _search_parallel

* fix bug in retraining

* check memory constraint before training

* retrain_full

* log custom metric

* retraining budget check

* sample size check before retrain

* remove 'time2eval' from result

* report 'total_search_time' in result

* rename total_search_time to wall_clock_time

* rename train_loss boolean to log_training_metric

* set default train_loss to None

* exclude oom result

* log retrained model

* no subsample

* doc str

* notebook

* predicted value is NaN for sarimax

* version

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Co-authored-by: Qingyun Wu <qxw5138@psu.edu>
Qingyun Wu 2021-08-23 19:36:51 -04:00 committed by GitHub
parent 3d0a3d26a2
commit a229a6112a
21 changed files with 5142 additions and 4677 deletions


@ -13,6 +13,7 @@ from sklearn.model_selection import train_test_split, RepeatedStratifiedKFold, \
RepeatedKFold, GroupKFold, TimeSeriesSplit
from sklearn.utils import shuffle
import pandas as pd
import logging
from .ml import compute_estimator, train_estimator, get_estimator_class, \
get_classification_objective
@ -24,8 +25,6 @@ from .data import concat
from . import tune
from .training_log import training_log_reader, training_log_writer
import logging
logger = logging.getLogger(__name__)
logger_formatter = logging.Formatter(
'[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s',
@ -56,6 +55,7 @@ class SearchState:
self.cat_hp_cost = {}
self.data_size = data_size
self.ls_ever_converged = False
self.learner_class = learner_class
search_space = learner_class.search_space(
data_size=data_size, task=task)
for name, space in search_space.items():
@ -86,10 +86,7 @@ class SearchState:
self.sample_size = None
self.trial_time = 0
def update(self, analysis, time_used, save_model_history=False):
if not analysis.trials:
return
result = analysis.trials[-1].last_result
def update(self, result, time_used, save_model_history=False):
if result:
config = result['config']
if config and 'FLAML_sample_size' in config:
@ -98,9 +95,8 @@ class SearchState:
self.sample_size = self.data_size
obj = result['val_loss']
train_loss = result['train_loss']
time2eval = result['time2eval']
trained_estimator = result[
'trained_estimator']
time2eval = result['time_total_s']
trained_estimator = result['trained_estimator']
del result['trained_estimator'] # free up RAM
else:
obj, time2eval, trained_estimator = np.inf, 0.0, None
@ -159,11 +155,10 @@ class AutoMLState:
if weight is not None:
sampled_weight = weight[:sample_size]
else:
sampled_X_train = concat(self.X_train, self.X_val)
sampled_y_train = np.concatenate([self.y_train, self.y_val])
weight = self.fit_kwargs.get('sample_weight')
if weight is not None:
sampled_weight = np.concatenate([weight, self.weight_val])
sampled_X_train = self.X_train_all
sampled_y_train = self.y_train_all
if 'sample_weight' in self.fit_kwargs:
sampled_weight = self.sample_weight_all
return sampled_X_train, sampled_y_train, sampled_weight
def _compute_with_config_base(self,
@ -187,7 +182,7 @@ class AutoMLState:
budget = time_left if sample_size == self.data_size else \
time_left / 2 * sample_size / self.data_size
trained_estimator, val_loss, train_loss, time2eval, pred_time = \
trained_estimator, val_loss, train_loss, _, pred_time = \
compute_estimator(
sampled_X_train,
sampled_y_train,
@ -208,7 +203,7 @@ class AutoMLState:
self.fit_kwargs)
result = {
'pred_time': pred_time,
'time2eval': time2eval,
'wall_clock_time': time.time() - self._start_time_flag,
'train_loss': train_loss,
'val_loss': val_loss,
'trained_estimator': trained_estimator
@ -251,6 +246,18 @@ class AutoMLState:
return estimator, train_time
def size(state: AutoMLState, config: dict) -> float:
'''Size function
Returns:
The mem size in bytes for a config
'''
config = config.get('ml', config)
estimator = config['learner']
learner_class = state.learner_classes.get(estimator)
return learner_class.size(config)
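Hoisting size out of the AutoML class makes it a plain, picklable function, which matters once configs are checked in parallel trials. A minimal sketch of wiring it in as a config constraint, assuming an AutoML instance automl exposing the _state and _mem_thres attributes used in the test changes near the end of this diff:

from functools import partial
from flaml import CFO
from flaml.automl import size

search_alg = CFO(
    metric='val_loss',
    space=automl.search_space,
    low_cost_partial_config=automl.low_cost_partial_config,
    # reject any config whose estimated memory size exceeds the threshold
    config_constraints=[(partial(size, automl._state), '<=', automl._mem_thres)],
)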
class AutoML:
'''The AutoML class
@ -517,11 +524,10 @@ class AutoML:
X_val, y_val = self._state.X_val, self._state.y_val
if issparse(X_val):
X_val = X_val.tocsr()
X_train_all, y_train_all = \
self._X_train_all, self._y_train_all
X_train_all, y_train_all = self._X_train_all, self._y_train_all
if issparse(X_train_all):
X_train_all = X_train_all.tocsr()
if (self._state.task == 'binary:logistic' or self._state.task == 'multi:softmax') \
if self._state.task in ('binary:logistic', 'multi:softmax') \
and self._state.fit_kwargs.get('sample_weight') is None \
and self._split_type != 'time':
# logger.info(f"label {pd.unique(y_train_all)}")
@ -552,12 +558,12 @@ class AutoML:
f"class {label} augmented from {rare_count} to {count}")
SHUFFLE_SPLIT_TYPES = ['uniform', 'stratified']
if self._split_type in SHUFFLE_SPLIT_TYPES:
if 'sample_weight' in self._state.fit_kwargs:
X_train_all, y_train_all, self._state.fit_kwargs[
'sample_weight'] = shuffle(
X_train_all, y_train_all,
self._state.fit_kwargs['sample_weight'],
if self._sample_weight_full is not None:
X_train_all, y_train_all, self._state.sample_weight_all = \
shuffle(X_train_all, y_train_all, self._sample_weight_full,
random_state=RANDOM_SEED)
self._state.fit_kwargs[
'sample_weight'] = self._state.sample_weight_all
elif hasattr(self._state, 'groups') and self._state.groups is not None:
X_train_all, y_train_all, self._state.groups = shuffle(
X_train_all, y_train_all, self._state.groups,
@ -658,12 +664,11 @@ class AutoML:
test_size=split_ratio,
random_state=RANDOM_SEED)
self._state.data_size = X_train.shape[0]
if X_val is None:
self.data_size_full = self._state.data_size
else:
self.data_size_full = self._state.data_size + X_val.shape[0]
self.data_size_full = len(y_train_all)
self._state.X_train, self._state.y_train, self._state.X_val, \
self._state.y_val = (X_train, y_train, X_val, y_val)
self._state.X_train_all = X_train_all
self._state.y_train_all = y_train_all
if hasattr(self._state, 'groups') and self._state.groups is not None:
logger.info("Using GroupKFold")
assert len(self._state.groups) == y_train_all.size, \
@ -789,7 +794,7 @@ class AutoML:
best = reader.get_record(record_id)
else:
for record in reader.records():
time_used = record.total_search_time
time_used = record.wall_clock_time
if time_used > time_budget:
break
training_duration = time_used
@ -941,7 +946,7 @@ class AutoML:
return config
@property
def points_to_evalaute(self) -> dict:
def points_to_evaluate(self) -> dict:
'''Initial points to evaluate
Returns:
@ -999,6 +1004,7 @@ class AutoML:
AutoMLState._compute_with_config_base,
self._state, estimator)
states = self._search_states
mem_res = self._mem_thres
def train(config: dict):
sample_size = config.get('FLAML_sample_size')
@ -1006,28 +1012,20 @@ class AutoML:
if sample_size:
config['FLAML_sample_size'] = sample_size
estimator = config['learner']
# check memory constraints before training
if states[estimator].learner_class.size(config) <= mem_res:
del config['learner']
result = states[estimator].training_function(config)
return result
else:
return {'pred_time': 0,
'wall_clock_time': None,
'train_loss': np.inf,
'val_loss': np.inf,
'trained_estimator': None
}
return train
@property
def size(self) -> Callable[[dict], float]:
'''Size function
Returns:
A function that returns the mem size in bytes for a config
'''
def size_func(config: dict) -> float:
config = config.get('ml', config)
estimator = config['learner']
learner_class = self._state.learner_classes.get(estimator)
return learner_class.size(config)
return size_func
@property
def metric_constraints(self) -> list:
'''Metric constraints
@ -1075,29 +1073,30 @@ class AutoML:
hpo_method=None,
starting_points={},
seed=None,
n_concurrent_trials=1,
**fit_kwargs):
'''Find a model for a given task
Args:
X_train: A numpy array or a pandas dataframe of training data in
shape (n, m)
For 'forecast' task, X_train should be timestamp
y_train: A numpy array or a pandas series of labels in shape (n,)
For 'forecast' task, y_train should be value
dataframe: A dataframe of training data including label column
shape (n, m). For 'forecast' task, X_train should contain a
single column of timestamps.
y_train: A numpy array or a pandas series of labels in shape (n, ).
dataframe: A dataframe of training data including label column.
For 'forecast' task, dataframe must be specified and should
have two columns: timestamp and value
have two columns: timestamp and value.
label: A str of the label column name for 'classification' or
'regression' task or a tuple of strings for timestamp and
value columns for 'forecasting' task
'regression' task, e.g., 'label';
or a tuple of strings for timestamp and value columns for
'forecasting' task, e.g., ('timestamp', 'value').
Note: If X_train and y_train are provided,
dataframe and label are ignored;
If not, dataframe and label must be provided.
metric: A string of the metric name or a function,
e.g., 'accuracy', 'roc_auc', 'roc_auc_ovr', 'roc_auc_ovo',
'f1', 'micro_f1', 'macro_f1', 'log_loss', 'mape', 'mae', 'mse', 'r2'
for 'forecast' task, use 'mape'
if passing a customized metric function, the function needs to
'f1', 'micro_f1', 'macro_f1', 'log_loss', 'mae', 'mse', 'r2',
'mape'.
If passing a customized metric function, the function needs to
have the following signature:
.. code-block:: python
@ -1109,11 +1108,11 @@ class AutoML:
return metric_to_minimize, metrics_to_log
which returns a float number as the minimization objective,
and a tuple of floats or a dictionary as the metrics to log
and a tuple of floats or a dictionary as the metrics to log.
task: A string of the task type, e.g.,
'classification', 'regression', 'forecast'
n_jobs: An integer of the number of threads for training
log_file_name: A string of the log file name
'classification', 'regression', 'forecast'.
n_jobs: An integer of the number of threads for training.
log_file_name: A string of the log file name.
estimator_list: A list of strings for estimator names, or 'auto'
e.g.,
@ -1121,51 +1120,60 @@ class AutoML:
['lgbm', 'xgboost', 'catboost', 'rf', 'extra_tree']
time_budget: A float number of the time budget in seconds
max_iter: An integer of the maximal number of iterations
time_budget: A float number of the time budget in seconds.
max_iter: An integer of the maximal number of iterations.
sample: A boolean of whether to sample the training data during
search
search.
eval_method: A string of resampling strategy, one of
['auto', 'cv', 'holdout']
split_ratio: A float of the validation data percentage for holdout
n_splits: An integer of the number of folds for cross-validation
['auto', 'cv', 'holdout'].
split_ratio: A float of the validation data percentage for holdout.
n_splits: An integer of the number of folds for cross-validation.
log_type: A string of the log type, one of
['better', 'all']
['better', 'all'].
'better' only logs configs with better loss than previous iters
'all' logs all the tried configs
'all' logs all the tried configs.
model_history: A boolean of whether to keep the history of best
models in the history property. Make sure memory is large
enough if setting to True.
log_training_metric: A boolean of whether to log the training
metric for each model.
mem_thres: A float of the memory size constraint in bytes
pred_time_limit: A float of the prediction latency constraint in seconds
train_time_limit: A float of the training time constraint in seconds
X_val: None or a numpy array or a pandas dataframe of validation data
y_val: None or a numpy array or a pandas series of validation labels
mem_thres: A float of the memory size constraint in bytes.
pred_time_limit: A float of the prediction latency constraint in seconds.
train_time_limit: A float of the training time constraint in seconds.
X_val: None or a numpy array or a pandas dataframe of validation data.
y_val: None or a numpy array or a pandas series of validation labels.
sample_weight_val: None or a numpy array of the sample weight of
validation data.
groups: None or an array-like of shape (n,) | Group labels for the
samples used while splitting the dataset into train/valid set
samples used while splitting the dataset into train/valid set.
verbose: int, default=1 | Controls the verbosity, higher means more
messages.
retrain_full: bool or str, default=True | whether to retrain the
selected model on the full training data when using holdout.
True - retrain only after search finishes; False - no retraining;
'budget' - do best effort to retrain without violating the time
budget.
hpo_method: str or None, default=None | The hyperparameter
optimization method. When it is None, CFO is used.
No need to set when using flaml's default search space or using
a simple customized search space. When set to 'bs', BlendSearch
is used. BlendSearch can be tried when the search space is
complex, for example, containing multiple disjoint, discontinuous
subspaces.
subspaces. When set to 'random' and the argument 'n_concurrent_trials'
is larger than 1, RandomSearch is used.
starting_points: A dictionary to specify the starting hyperparameter
config for the estimators.
Keys are the name of the estimators, and values are the starting
hyperparameter configurations for the corresponding estimators.
seed: int or None, default=None | The random seed for np.random.
n_concurrent_trials: [Experimental] int, default=1 | The number of
concurrent trials. For n_concurrent_trials > 1, installation of
ray is required: `pip install flaml[ray]`.
**fit_kwargs: Other keyword arguments to pass to the fit() function of
the searched learners, such as sample_weight. Include period as
a keyword argument for 'forecast' task.
'''
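The customized-metric contract documented above is exercised by the test changes later in this commit; a minimal conforming sketch (the log_loss choice is illustrative, not prescribed):

def custom_metric(X_test, y_test, estimator, labels,
                  X_train, y_train, weight_test=None, weight_train=None):
    from sklearn.metrics import log_loss
    test_loss = log_loss(y_test, estimator.predict_proba(X_test),
                         labels=labels, sample_weight=weight_test)
    train_loss = log_loss(y_train, estimator.predict_proba(X_train),
                          labels=labels, sample_weight=weight_train)
    # the first value is minimized; the dict is logged as custom metrics
    return test_loss, {'test_loss': test_loss, 'train_loss': train_loss}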
self._start_time_flag = time.time()
self._state._start_time_flag = self._start_time_flag = time.time()
self._state.task = task
self._state.log_training_metric = log_training_metric
self._state.fit_kwargs = fit_kwargs
@ -1194,10 +1202,12 @@ class AutoML:
self._split_type = "uniform"
elif self._state.task == 'forecast':
if split_type is not None and split_type != 'time':
raise ValueError("split_type must be 'time' when task is 'forecast'. ")
raise ValueError(
"split_type must be 'time' when task is 'forecast'.")
self._split_type = "time"
if self._state.task == 'forecast' and self._state.fit_kwargs.get('period') is None:
raise TypeError("missing 1 required argument for 'forecast' task: 'period'. ")
if self._state.fit_kwargs.get('period') is None:
raise TypeError(
"missing 1 required argument for 'forecast' task: 'period'.")
if eval_method == 'auto' or self._state.X_val is not None:
eval_method = self._decide_eval_method(time_budget)
self._state.eval_method = eval_method
@ -1208,13 +1218,16 @@ class AutoML:
logger.addHandler(_ch)
logger.info("Evaluation method: {}".format(eval_method))
self._retrain_full = retrain_full and (
self._retrain_in_budget = retrain_full == 'budget' and (
eval_method == 'holdout' and self._state.X_val is None)
self._retrain_final = retrain_full is True and (
eval_method == 'holdout' and self._state.X_val is None) or (
eval_method == 'cv')
if self._state.task != 'forecast':
self._prepare_data(eval_method, split_ratio, n_splits)
else:
self._prepare_data(eval_method, split_ratio, n_splits,
period=self._state.fit_kwargs.get('period'))
period=self._state.fit_kwargs['period'])
self._sample = sample and eval_method != 'cv' and (
MIN_SAMPLE_TRAIN * SAMPLE_MULTIPLY_FACTOR < self._state.data_size)
if 'auto' == metric:
@ -1237,11 +1250,13 @@ class AutoML:
logger.info(f'Minimizing error metric: {error_metric}')
if 'auto' == estimator_list:
estimator_list = ['lgbm', 'rf', 'catboost', 'xgboost', 'extra_tree']
if 'regression' != self._state.task:
estimator_list += ['lrl1']
if self._state.task == 'forecast':
estimator_list = ['fbprophet', 'arima', 'sarimax']
else:
estimator_list = [
'lgbm', 'rf', 'catboost', 'xgboost', 'extra_tree']
if 'regression' != self._state.task:
estimator_list += ['lrl1']
for estimator_name in estimator_list:
if estimator_name not in self._state.learner_classes:
self.add_learner(
@ -1271,6 +1286,7 @@ class AutoML:
self.split_ratio = split_ratio
self._save_model_history = model_history
self._state.n_jobs = n_jobs
self._n_concurrent_trials = n_concurrent_trials
if log_file_name:
with training_log_writer(log_file_name) as save_helper:
self._training_log = save_helper
@ -1278,10 +1294,13 @@ class AutoML:
else:
self._training_log = None
self._search()
if self._best_estimator:
logger.info("fit succeeded")
logger.info(f"Time taken to find the best model: {self._time_taken_best_iter}")
if self._time_taken_best_iter >= time_budget * 0.7 and not \
all(self._ever_converged_per_learner.values()):
if self._time_taken_best_iter >= time_budget * 0.7 and not all(
state.search_alg and state.search_alg.searcher.is_ls_ever_converged
for state in self._search_states.values()
):
logger.warn("Time taken to find the best model is {0:.0f}% of the "
"provided time budget and not all estimators' hyperparameter "
"search converged. Consider increasing the time budget.".format(
@ -1290,32 +1309,116 @@ class AutoML:
if verbose == 0:
logger.setLevel(old_level)
def _search(self):
# initialize the search_states
self._eci = []
self._state.best_loss = float('+inf')
self._state.time_from_start = 0
self._estimator_index = None
self._best_iteration = 0
self._time_taken_best_iter = 0
self._model_history = {}
self._config_history = {}
self._max_iter_per_learner = 1000000 # TODO
self._iter_per_learner = dict([(e, 0) for e in self.estimator_list])
self._ever_converged_per_learner = dict([(e, False) for e in self.estimator_list])
self._fullsize_reached = False
self._trained_estimator = None
self._best_estimator = None
self._retrained_config = {}
self._warn_threshold = 10
def _search_parallel(self):
try:
from ray import __version__ as ray_version
assert ray_version >= '1.0.0'
import ray
from ray.tune.suggest import ConcurrencyLimiter
except (ImportError, AssertionError):
raise ImportError(
"n_concurrent_trial > 1 requires installation of ray. "
"Please run pip install flaml[ray]")
if self._hpo_method in ('cfo', 'grid'):
from flaml import CFO as SearchAlgo
elif 'optuna' == self._hpo_method:
from ray.tune.suggest.optuna import OptunaSearch as SearchAlgo
elif 'bs' == self._hpo_method:
from flaml import BlendSearch as SearchAlgo
elif 'cfocat' == self._hpo_method:
from flaml.searcher.cfo_cat import CFOCat as SearchAlgo
elif 'random' == self._hpo_method:
from ray.tune.suggest import BasicVariantGenerator as SearchAlgo
from ray.tune.sample import Domain as RayDomain
from .tune.sample import Domain
else:
raise NotImplementedError(
f"hpo_method={self._hpo_method} is not recognized. "
"'cfo' and 'bs' are supported.")
if self._hpo_method == 'random':
# Any point in points_to_evaluate must consist of hyperparameters
# that are tunable, which can be identified by checking whether
# the corresponding value in the search space is an instance of
# the 'Domain' class from flaml or ray.tune
points_to_evaluate = self.points_to_evaluate.copy()
to_del = []
for k, v in self.search_space.items():
if not (isinstance(v, Domain) or isinstance(v, RayDomain)):
to_del.append(k)
for k in to_del:
for p in points_to_evaluate:
del p[k]
est_retrain_time = next_trial_time = 0
best_config_sig = None
# use ConcurrencyLimiter to limit the amount of concurrency when
# using a search algorithm
better = True # whether we find a better model in one trial
if self._ensemble:
self.best_model = {}
search_alg = SearchAlgo(max_concurrent=self._n_concurrent_trials,
points_to_evaluate=points_to_evaluate)
else:
search_alg = SearchAlgo(
metric='val_loss',
space=self.search_space,
low_cost_partial_config=self.low_cost_partial_config,
points_to_evaluate=self.points_to_evaluate,
cat_hp_cost=self.cat_hp_cost,
prune_attr=self.prune_attr,
min_resource=self.min_resource,
max_resource=self.max_resource,
config_constraints=[(partial(size, self._state), '<=', self._mem_thres)],
metric_constraints=self.metric_constraints)
search_alg = ConcurrencyLimiter(search_alg, self._n_concurrent_trials)
self._state.time_from_start = time.time() - self._start_time_flag
time_left = self._state.time_budget - self._state.time_from_start
search_alg.set_search_properties(None, None, config={
'time_budget_s': time_left})
resources_per_trial = {
"cpu": self._state.n_jobs} if self._state.n_jobs > 1 else None
analysis = ray.tune.run(
self.trainable, search_alg=search_alg, config=self.search_space,
metric='val_loss', mode='min', resources_per_trial=resources_per_trial,
time_budget_s=self._state.time_budget, num_samples=self._max_iter)
# logger.info([trial.last_result for trial in analysis.trials])
trials = sorted((trial for trial in analysis.trials if trial.last_result
and trial.last_result['wall_clock_time'] is not None),
key=lambda x: x.last_result['wall_clock_time'])
for _track_iter, trial in enumerate(trials):
result = trial.last_result
better = False
if result:
config = result['config']
estimator = config.get('ml', config)['learner']
search_state = self._search_states[estimator]
search_state.update(result, 0, self._save_model_history)
if result['wall_clock_time'] is not None:
self._state.time_from_start = result['wall_clock_time']
if search_state.sample_size == self._state.data_size:
self._iter_per_learner[estimator] += 1
if not self._fullsize_reached:
self._fullsize_reached = True
if search_state.best_loss < self._state.best_loss:
self._state.best_loss = search_state.best_loss
self._best_estimator = estimator
self._config_history[_track_iter] = (
self._best_estimator, config, self._time_taken_best_iter)
if self._save_model_history:
self._model_history[_track_iter] = search_state.trained_estimator
self._trained_estimator = search_state.trained_estimator
self._best_iteration = _track_iter
self._time_taken_best_iter = self._state.time_from_start
better = True
self._search_states[estimator].best_config = config
if (better or self._log_type == 'all') and self._training_log:
self._training_log.append(
self._iter_per_learner[estimator],
search_state.train_loss,
search_state.trial_time,
self._state.time_from_start,
search_state.val_loss,
config,
self._state.best_loss,
search_state.best_config,
estimator,
search_state.sample_size)
def _search_sequential(self):
try:
from ray import __version__ as ray_version
assert ray_version >= '1.0.0'
@ -1339,6 +1442,11 @@ class AutoML:
f"hpo_method={self._hpo_method} is not recognized. "
"'cfo' and 'bs' are supported.")
est_retrain_time = next_trial_time = 0
best_config_sig = None
better = True # whether we find a better model in one trial
if self._ensemble:
self.best_model = {}
for self._track_iter in range(self._max_iter):
if self._estimator_index is None:
estimator = self._active_estimators[0]
@ -1351,7 +1459,7 @@ class AutoML:
search_state = self._search_states[estimator]
self._state.time_from_start = time.time() - self._start_time_flag
time_left = self._state.time_budget - self._state.time_from_start
budget_left = time_left if not self._retrain_full or better or (
budget_left = time_left if not self._retrain_in_budget or better or (
not self.best_estimator) or self._search_states[
self.best_estimator].sample_size < self._state.data_size \
else time_left - est_retrain_time
@ -1404,6 +1512,7 @@ class AutoML:
)
search_state.search_alg = ConcurrencyLimiter(algo,
max_concurrent=1)
# search_state.search_alg = algo
else:
search_space = None
if self._hpo_method in ('bs', 'cfo', 'cfocat'):
@ -1423,7 +1532,9 @@ class AutoML:
time_used = time.time() - start_run_time
better = False
if analysis.trials:
search_state.update(analysis, time_used=time_used,
result = analysis.trials[-1].last_result
search_state.update(result,
time_used=time_used,
save_model_history=self._save_model_history)
if self._estimator_index is None:
eci_base = search_state.init_eci
@ -1432,7 +1543,8 @@ class AutoML:
self._eci.append(self._search_states[e].init_eci
/ eci_base * self._eci[0])
self._estimator_index = 0
self._state.time_from_start = time.time() - self._start_time_flag
if result['wall_clock_time'] is not None:
self._state.time_from_start = result['wall_clock_time']
# logger.info(f"{self._search_states[estimator].sample_size}, {data_size}")
if search_state.sample_size == self._state.data_size:
self._iter_per_learner[estimator] += 1
@ -1483,7 +1595,7 @@ class AutoML:
search_state.train_loss)
mlflow.log_metric('trial_time',
search_state.trial_time)
mlflow.log_metric('total_search_time',
mlflow.log_metric('wall_clock_time',
self._state.time_from_start)
mlflow.log_metric('validation_loss',
search_state.val_loss)
@ -1506,11 +1618,10 @@ class AutoML:
search_state.best_loss,
self._best_estimator,
self._state.best_loss))
searcher = search_state.search_alg.searcher
if searcher.is_ls_ever_converged and not self._ever_converged_per_learner[estimator]:
self._ever_converged_per_learner[estimator] = searcher.is_ls_ever_converged
if all(self._ever_converged_per_learner.values()) and \
self._state.time_from_start > self._warn_threshold * self._time_taken_best_iter:
if all(state.search_alg and state.search_alg.searcher.is_ls_ever_converged
for state in self._search_states.values()) and (
self._state.time_from_start
> self._warn_threshold * self._time_taken_best_iter):
logger.warn("All estimator hyperparameters local search has converged at least once, "
f"and the total search time exceeds {self._warn_threshold} times the time taken "
"to find the best model.")
@ -1520,10 +1631,10 @@ class AutoML:
if self._estimator_index is not None:
self._active_estimators.remove(estimator)
self._estimator_index -= 1
if self._retrain_full and best_config_sig and not better and (
self._search_states[
self._best_estimator].sample_size == self._state.data_size
) and (est_retrain_time
if self._retrain_in_budget and best_config_sig and est_retrain_time \
and not better and self._search_states[
self._best_estimator].sample_size == self._state.data_size and (
est_retrain_time
<= self._state.time_budget - self._state.time_from_start
<= est_retrain_time + next_trial_time):
self._trained_estimator, \
@ -1532,7 +1643,7 @@ class AutoML:
self._search_states[self._best_estimator].best_config,
self.data_size_full)
logger.info("retrain {} for {:.1f}s".format(
estimator, retrain_time))
self._best_estimator, retrain_time))
self._retrained_config[best_config_sig] = retrain_time
est_retrain_time = 0
self._state.time_from_start = time.time() - self._start_time_flag
@ -1545,12 +1656,34 @@ class AutoML:
self._best_estimator].time2eval_best
if time_left < time_ensemble < 2 * time_left:
break
def _search(self):
# initialize the search_states
self._eci = []
self._state.best_loss = float('+inf')
self._state.time_from_start = 0
self._estimator_index = None
self._best_iteration = 0
self._time_taken_best_iter = 0
self._model_history = {}
self._config_history = {}
self._max_iter_per_learner = 1000000 # TODO
self._iter_per_learner = dict([(e, 0) for e in self.estimator_list])
self._fullsize_reached = False
self._trained_estimator = None
self._best_estimator = None
self._retrained_config = {}
self._warn_threshold = 10
if self._n_concurrent_trials == 1:
self._search_sequential()
else:
self._search_parallel()
# Add a checkpoint for the current best config to the log.
if self._training_log:
self._training_log.checkpoint()
if self._best_estimator:
self._selected = self._search_states[self._best_estimator]
self._trained_estimator = self._selected.trained_estimator
self.modelcount = sum(
search_state.total_iter
for search_state in self._search_states.values())
@ -1585,6 +1718,25 @@ class AutoML:
logger.info(f'ensemble: {stacker}')
self._trained_estimator = stacker
self._trained_estimator.model = stacker
elif self._retrain_final:
# reset time budget for retraining
self._state.time_from_start -= self._state.time_budget
if (self._state.time_budget - self._state.time_from_start
> self._selected.est_retrain_time(self.data_size_full)) \
and self._selected.best_config_sample_size == self._state.data_size:
self._trained_estimator, \
retrain_time = self._state._train_with_config(
self._best_estimator,
self._search_states[self._best_estimator].best_config,
self.data_size_full)
logger.info("retrain {} for {:.1f}s".format(
self._best_estimator, retrain_time))
if self._trained_estimator:
logger.info(
f'retrained model: {self._trained_estimator.model}')
else:
logger.info(
"not retraining because the time budget is too small.")
else:
self._selected = self._trained_estimator = None
self.modelcount = 0
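Putting the new pieces together: a hedged usage sketch of parallel random search, mirroring the test_parallel_xgboost test added later in this commit (data and budget values are illustrative):

import numpy as np
import scipy.sparse
from flaml import AutoML

automl = AutoML()
X_train = scipy.sparse.eye(900000)
y_train = np.random.randint(2, size=900000)
automl.fit(
    X_train=X_train, y_train=y_train,
    task='classification', metric='ap',
    time_budget=10, estimator_list=['xgboost'], n_jobs=1,
    n_concurrent_trials=2,  # >1 requires `pip install flaml[ray]`
    hpo_method='random',    # RandomSearch when n_concurrent_trials > 1
)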


@ -141,14 +141,14 @@ def get_output_from_log(filename, time_budget):
best_config_list = []
with training_log_reader(filename) as reader:
for record in reader.records():
time_used = record.total_search_time
time_used = record.wall_clock_time
val_loss = record.validation_loss
config = record.config
learner = record.learner.split('_')[0]
sample_size = record.sample_size
train_loss = record.logged_metric
if time_used < time_budget:
if time_used < time_budget and np.isfinite(val_loss):
if val_loss < best_val_loss:
best_val_loss = val_loss
best_config = config


@ -102,8 +102,11 @@ def sklearn_metric_loss_score(
score = log_loss(
y_true, y_predict, labels=labels, sample_weight=sample_weight)
elif 'mape' in metric_name:
try:
score = mean_absolute_percentage_error(
y_true, y_predict)
except ValueError:
return np.inf
elif 'micro_f1' in metric_name:
score = 1 - f1_score(
y_true, y_predict, sample_weight=sample_weight, average='micro')
@ -141,21 +144,23 @@ def get_y_pred(estimator, X, eval_metric, obj, freq=None):
def get_test_loss(
estimator, X_train, y_train, X_test, y_test, weight_test,
eval_metric, obj, labels=None, budget=None, train_loss=False, fit_kwargs={}
eval_metric, obj, labels=None, budget=None, log_training_metric=False, fit_kwargs={}
):
start = time.time()
train_time = estimator.fit(X_train, y_train, budget, **fit_kwargs)
estimator.fit(X_train, y_train, budget, **fit_kwargs)
if isinstance(eval_metric, str):
pred_start = time.time()
test_pred_y = get_y_pred(estimator, X_test, eval_metric, obj)
pred_time = (time.time() - pred_start) / X_test.shape[0]
test_loss = sklearn_metric_loss_score(eval_metric, test_pred_y, y_test,
labels, weight_test)
if train_loss is not False:
if log_training_metric:
test_pred_y = get_y_pred(estimator, X_train, eval_metric, obj)
train_loss = sklearn_metric_loss_score(
eval_metric, test_pred_y,
y_train, labels, fit_kwargs.get('sample_weight'))
else:
train_loss = None
else: # customized metric function
test_loss, metrics = eval_metric(
X_test, y_test, estimator, labels, X_train, y_train,
@ -174,40 +179,41 @@ def train_model(estimator, X_train, y_train, budget, fit_kwargs={}):
def evaluate_model(
estimator, X_train, y_train, X_val, y_val, weight_val,
budget, kf, task, eval_method, eval_metric, best_val_loss, train_loss=False,
budget, kf, task, eval_method, eval_metric, best_val_loss, log_training_metric=False,
fit_kwargs={}
):
if 'holdout' in eval_method:
val_loss, train_loss, train_time, pred_time = evaluate_model_holdout(
estimator, X_train, y_train, X_val, y_val, weight_val, budget,
task, eval_metric, train_loss=train_loss,
task, eval_metric, log_training_metric=log_training_metric,
fit_kwargs=fit_kwargs)
else:
val_loss, train_loss, train_time, pred_time = evaluate_model_CV(
estimator, X_train, y_train, budget, kf, task,
eval_metric, best_val_loss, train_loss=train_loss,
eval_metric, best_val_loss, log_training_metric=log_training_metric,
fit_kwargs=fit_kwargs)
return val_loss, train_loss, train_time, pred_time
def evaluate_model_holdout(
estimator, X_train, y_train, X_val, y_val,
weight_val, budget, task, eval_metric, train_loss=False,
weight_val, budget, task, eval_metric, log_training_metric=False,
fit_kwargs={}
):
val_loss, train_time, train_loss, pred_time = get_test_loss(
estimator, X_train, y_train, X_val, y_val, weight_val, eval_metric,
task, budget=budget, train_loss=train_loss, fit_kwargs=fit_kwargs)
task, budget=budget, log_training_metric=log_training_metric, fit_kwargs=fit_kwargs)
return val_loss, train_loss, train_time, pred_time
def evaluate_model_CV(
estimator, X_train_all, y_train_all, budget, kf,
task, eval_metric, best_val_loss, train_loss=False, fit_kwargs={}
task, eval_metric, best_val_loss, log_training_metric=False, fit_kwargs={}
):
start_time = time.time()
total_val_loss = 0
total_train_loss = None
train_loss = None
train_time = pred_time = 0
valid_fold_num = total_fold_num = 0
n = kf.get_n_splits()
@ -231,7 +237,7 @@ def evaluate_model_CV(
kf = kf.split(X_train_split)
rng = np.random.RandomState(2020)
val_loss_list = []
budget_per_train = budget / (n + 1)
budget_per_train = budget / n
if 'sample_weight' in fit_kwargs:
weight = fit_kwargs['sample_weight']
weight_val = None
@ -259,13 +265,13 @@ def evaluate_model_CV(
val_loss_i, train_time_i, train_loss_i, pred_time_i = get_test_loss(
estimator, X_train, y_train, X_val, y_val, weight_val,
eval_metric, task, labels, budget_per_train,
train_loss=train_loss, fit_kwargs=fit_kwargs)
log_training_metric=log_training_metric, fit_kwargs=fit_kwargs)
if weight is not None:
fit_kwargs['sample_weight'] = weight
valid_fold_num += 1
total_fold_num += 1
total_val_loss += val_loss_i
if train_loss is not False:
if log_training_metric or not isinstance(eval_metric, str):
if isinstance(total_train_loss, list):
total_train_loss = [
total_train_loss[i] + v for i, v in enumerate(train_loss_i)]
@ -286,7 +292,7 @@ def evaluate_model_CV(
break
val_loss = np.max(val_loss_list)
n = total_fold_num
if train_loss is not False:
if log_training_metric or not isinstance(eval_metric, str):
if isinstance(total_train_loss, list):
train_loss = [v / n for v in total_train_loss]
elif isinstance(total_train_loss, dict):
@ -294,17 +300,17 @@ def evaluate_model_CV(
else:
train_loss = total_train_loss / n
pred_time /= n
budget -= time.time() - start_time
if val_loss < best_val_loss and budget > budget_per_train:
estimator.cleanup()
estimator.fit(X_train_all, y_train_all, budget, **fit_kwargs)
# budget -= time.time() - start_time
# if val_loss < best_val_loss and budget > budget_per_train:
# estimator.cleanup()
# estimator.fit(X_train_all, y_train_all, budget, **fit_kwargs)
return val_loss, train_loss, train_time, pred_time
def compute_estimator(
X_train, y_train, X_val, y_val, weight_val, budget, kf,
config_dic, task, estimator_name, eval_method, eval_metric,
best_val_loss=np.Inf, n_jobs=1, estimator_class=None, train_loss=False,
best_val_loss=np.Inf, n_jobs=1, estimator_class=None, log_training_metric=False,
fit_kwargs={}
):
estimator_class = estimator_class or get_estimator_class(
@ -313,7 +319,7 @@ def compute_estimator(
**config_dic, task=task, n_jobs=n_jobs)
val_loss, train_loss, train_time, pred_time = evaluate_model(
estimator, X_train, y_train, X_val, y_val, weight_val, budget, kf, task,
eval_method, eval_metric, best_val_loss, train_loss=train_loss,
eval_method, eval_metric, best_val_loss, log_training_metric=log_training_metric,
fit_kwargs=fit_kwargs)
return estimator, val_loss, train_loss, train_time, pred_time
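The new try/except around 'mape' above means a failed metric computation (e.g., NaN predictions from sarimax, per the commit message) yields np.inf instead of crashing the trial, so the result can be excluded as an invalid loss. A small sketch of the helper's calling convention, with the argument order taken from the call sites in this hunk:

import numpy as np
from flaml.ml import sklearn_metric_loss_score

y_true = np.array([1.0, 2.0, 4.0])
y_pred = np.array([1.1, 1.9, 4.2])
# metric name first, then predictions, then ground truth
score = sklearn_metric_loss_score('mape', y_pred, y_true)
print(score)  # finite MAPE here; np.inf if the underlying metric raises ValueError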


@ -222,10 +222,10 @@ class LGBMEstimator(BaseEstimator):
'domain': tune.loguniform(lower=1 / 1024, upper=1.0),
'init_value': 0.1,
},
'subsample': {
'domain': tune.uniform(lower=0.1, upper=1.0),
'init_value': 1.0,
},
# 'subsample': {
# 'domain': tune.uniform(lower=0.1, upper=1.0),
# 'init_value': 1.0,
# },
'log_max_bin': {
'domain': tune.lograndint(lower=3, upper=11),
'init_value': 8,
@ -252,6 +252,7 @@ class LGBMEstimator(BaseEstimator):
def __init__(self, task='binary:logistic', log_max_bin=8, **params):
super().__init__(task, **params)
if "objective" not in self.params:
# Default: regression for LGBMRegressor,
# binary or multiclass for LGBMClassifier
if 'regression' in task:
@ -262,18 +263,19 @@ class LGBMEstimator(BaseEstimator):
objective = 'multiclass'
else:
objective = 'regression'
self.params["objective"] = objective
if "n_estimators" in self.params:
self.params["n_estimators"] = int(round(self.params["n_estimators"]))
if "num_leaves" in self.params:
self.params["num_leaves"] = int(round(self.params["num_leaves"]))
if "min_child_samples" in self.params:
self.params["min_child_samples"] = int(round(self.params["min_child_samples"]))
if "objective" not in self.params:
self.params["objective"] = objective
if "max_bin" not in self.params:
self.params['max_bin'] = 1 << int(round(log_max_bin)) - 1
if "verbose" not in self.params:
self.params['verbose'] = -1
# if "subsample_freq" not in self.params:
# self.params['subsample_freq'] = 1
if 'regression' in task:
self.estimator_class = LGBMRegressor
else:


@ -748,6 +748,7 @@ class AutoTransformers:
self._set_metric(custom_metric_name, custom_metric_mode_name)
self._set_task()
self._fp16 = fp16
ray.shutdown()
ray.init(local_mode=ray_local_mode)
self._set_search_space(**custom_hpo_args)


@ -3,6 +3,7 @@
* Licensed under the MIT License. See LICENSE file in the
* project root for license information.
'''
from flaml.tune.sample import Domain
from typing import Dict, Optional, Tuple
import numpy as np
try:
@ -140,7 +141,7 @@ class FLOW2(Searcher):
if str(sampler) != 'Normal':
self._bounded_keys.append(key)
if not hier:
self._space_keys = sorted(self._space.keys())
self._space_keys = sorted(self._tunable_keys)
self._hierarchical = hier
if (self.prune_attr and self.prune_attr not in self._space
and self.max_resource):
@ -499,15 +500,25 @@ class FLOW2(Searcher):
else:
space = self._space
value_list = []
# self._space_keys doesn't contain keys with const values,
# e.g., "eval_metric": ["logloss", "error"].
keys = sorted(config.keys()) if self._hierarchical else self._space_keys
for key in keys:
value = config[key]
if key == self.prune_attr:
value_list.append(value)
# else key must be in self.space
# get rid of list type or constant,
# e.g., "eval_metric": ["logloss", "error"]
elif isinstance(space[key], sample.Integer):
else:
# key must be in space
domain = space[key]
if self._hierarchical:
# can't remove constant for hierarchical search space,
# e.g., learner
if not (domain is None or type(domain) in (str, int, float)
or isinstance(domain, sample.Domain)):
# not domain or hashable
# get rid of list type for hierarchical search space.
continue
if isinstance(domain, sample.Integer):
value_list.append(int(round(value)))
else:
value_list.append(value)


@ -16,7 +16,7 @@ class TrainingLogRecord(object):
iter_per_learner: int,
logged_metric: float,
trial_time: float,
total_search_time: float,
wall_clock_time: float,
validation_loss,
config,
best_validation_loss,
@ -27,7 +27,7 @@ class TrainingLogRecord(object):
self.iter_per_learner = iter_per_learner
self.logged_metric = logged_metric
self.trial_time = trial_time
self.total_search_time = total_search_time
self.wall_clock_time = wall_clock_time
self.validation_loss = validation_loss
self.config = config
self.best_validation_loss = best_validation_loss
@ -71,7 +71,7 @@ class TrainingLogWriter(object):
it_counter: int,
train_loss: float,
trial_time: float,
total_search_time: float,
wall_clock_time: float,
validation_loss,
config,
best_validation_loss,
@ -86,7 +86,7 @@ class TrainingLogWriter(object):
it_counter,
train_loss,
trial_time,
total_search_time,
wall_clock_time,
validation_loss,
config,
best_validation_loss,
@ -95,6 +95,7 @@ class TrainingLogWriter(object):
sample_size)
if validation_loss < self.current_best_loss or \
validation_loss == self.current_best_loss and \
self.current_sample_size is not None and \
sample_size > self.current_sample_size:
self.current_best_loss = validation_loss
self.current_sample_size = sample_size
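Downstream consumers read the renamed field via record.wall_clock_time, as get_output_from_log does in the earlier hunk. A brief sketch of plotting-oriented use; the unpacking order follows the FLAML notebooks and the variable names are descriptive assumptions:

from flaml.data import get_output_from_log

time_history, best_valid_loss_history, valid_loss_history, \
    config_history, metric_history = get_output_from_log(
        filename='airlines_experiment.log', time_budget=240)
for t, loss in zip(time_history, best_valid_loss_history):
    print(f'wall clock time: {t:.2f}s, best validation loss: {loss:.4f}')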


@ -363,6 +363,7 @@ def indexof(domain: Dict, config: Dict) -> int:
continue
# print(domain.const[i])
if all(config[key] == value for key, value in domain.const[i].items()):
# assumption: the concatenation of constants is a unique identifier
return i
return None


@ -1 +1 @@
__version__ = "0.5.13"
__version__ = "0.6.0"

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@ -10,7 +10,7 @@ from datetime import datetime
from flaml import AutoML
from flaml.data import get_output_from_log
from flaml.model import SKLearnEstimator, XGBoostEstimator
from flaml.model import LGBMEstimator, SKLearnEstimator, XGBoostEstimator
from rgf.sklearn import RGFClassifier, RGFRegressor
from flaml import tune
@ -92,6 +92,24 @@ class MyXGB2(XGBoostEstimator):
super().__init__(objective='reg:squarederror', **params)
class MyLargeLGBM(LGBMEstimator):
@classmethod
def search_space(cls, **params):
return {
'n_estimators': {
'domain': tune.lograndint(lower=4, upper=32768),
'init_value': 32768,
'low_cost_init_value': 4,
},
'num_leaves': {
'domain': tune.lograndint(lower=4, upper=32768),
'init_value': 32768,
'low_cost_init_value': 4,
},
}
def custom_metric(X_test, y_test, estimator, labels, X_train, y_train,
weight_test=None, weight_train=None):
from sklearn.metrics import log_loss
@ -477,6 +495,66 @@ class TestAutoML(unittest.TestCase):
print(automl_experiment.best_iteration)
print(automl_experiment.best_estimator)
def test_parallel_xgboost(self, hpo_method=None):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 10,
"metric": 'ap',
"task": 'classification',
"log_file_name": "test/sparse_classification.log",
"estimator_list": ["xgboost"],
"log_type": "all",
"n_jobs": 1,
"n_concurrent_trials": 2,
"hpo_method": hpo_method,
}
X_train = scipy.sparse.eye(900000)
y_train = np.random.randint(2, size=900000)
try:
automl_experiment.fit(X_train=X_train, y_train=y_train,
**automl_settings)
print(automl_experiment.predict(X_train))
print(automl_experiment.model)
print(automl_experiment.config_history)
print(automl_experiment.model_history)
print(automl_experiment.best_iteration)
print(automl_experiment.best_estimator)
except ImportError:
return
def test_parallel_xgboost_random(self):
# use random search as the hpo_method
self.test_parallel_xgboost(hpo_method='random')
def test_random_out_of_memory(self):
automl_experiment = AutoML()
automl_experiment.add_learner(learner_name='large_lgbm', learner_class=MyLargeLGBM)
automl_settings = {
"time_budget": 2,
"metric": 'ap',
"task": 'classification',
"log_file_name": "test/sparse_classification_oom.log",
"estimator_list": ["large_lgbm"],
"log_type": "all",
"n_jobs": 1,
"n_concurrent_trials": 2,
"hpo_method": 'random',
}
X_train = scipy.sparse.eye(900000)
y_train = np.random.randint(2, size=900000)
try:
automl_experiment.fit(X_train=X_train, y_train=y_train,
**automl_settings)
print(automl_experiment.predict(X_train))
print(automl_experiment.model)
print(automl_experiment.config_history)
print(automl_experiment.model_history)
print(automl_experiment.best_iteration)
print(automl_experiment.best_estimator)
except ImportError:
return
def test_sparse_matrix_lr(self):
automl_experiment = AutoML()
automl_settings = {


@ -17,6 +17,7 @@ def test_automl(budget=5, dataset_format='dataframe'):
"metric": 'accuracy', # primary metrics can be chosen from: ['accuracy','roc_auc','roc_auc_ovr','roc_auc_ovo','f1','log_loss','mae','mse','r2']
"task": 'classification', # task type
"log_file_name": 'airlines_experiment.log', # flaml log file
"seed": 7654321, # random seed
}
'''The main flaml automl API'''
automl.fit(X_train=X_train, y_train=y_train, **settings)


@ -45,7 +45,7 @@ class TestLogging(unittest.TestCase):
**automl_settings)
logger.info(automl.search_space)
logger.info(automl.low_cost_partial_config)
logger.info(automl.points_to_evalaute)
logger.info(automl.points_to_evaluate)
logger.info(automl.cat_hp_cost)
import optuna as ot
study = ot.create_study()
@ -62,16 +62,18 @@ class TestLogging(unittest.TestCase):
config['learner'] = automl.best_estimator
automl.trainable({"ml": config})
from flaml import tune, CFO
from flaml.automl import size
from functools import partial
search_alg = CFO(
metric='val_loss',
space=automl.search_space,
low_cost_partial_config=automl.low_cost_partial_config,
points_to_evaluate=automl.points_to_evalaute,
points_to_evaluate=automl.points_to_evaluate,
cat_hp_cost=automl.cat_hp_cost,
prune_attr=automl.prune_attr,
min_resource=automl.min_resource,
max_resource=automl.max_resource,
config_constraints=[(automl.size, '<=', automl._mem_thres)],
config_constraints=[(partial(size, automl._state), '<=', automl._mem_thres)],
metric_constraints=automl.metric_constraints)
analysis = tune.run(
automl.trainable, search_alg=search_alg, # verbose=2,


@ -40,6 +40,7 @@ def test_simple(method=None):
"n_jobs": 1,
"hpo_method": method,
"log_type": "all",
"retrain_full": "budget",
"time_budget": 1
}
from sklearn.externals._arff import ArffException
@ -53,21 +54,23 @@ def test_simple(method=None):
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl.estimator_list)
print(automl.search_space)
print(automl.points_to_evalaute)
print(automl.points_to_evaluate)
config = automl.best_config.copy()
config['learner'] = automl.best_estimator
automl.trainable(config)
from flaml import tune
from flaml.automl import size
from functools import partial
analysis = tune.run(
automl.trainable, automl.search_space, metric='val_loss', mode="min",
low_cost_partial_config=automl.low_cost_partial_config,
points_to_evaluate=automl.points_to_evalaute,
points_to_evaluate=automl.points_to_evaluate,
cat_hp_cost=automl.cat_hp_cost,
prune_attr=automl.prune_attr,
min_resource=automl.min_resource,
max_resource=automl.max_resource,
time_budget_s=automl._state.time_budget,
config_constraints=[(automl.size, '<=', automl._mem_thres)],
config_constraints=[(partial(size, automl._state), '<=', automl._mem_thres)],
metric_constraints=automl.metric_constraints, num_samples=5)
print(analysis.trials[-1])


@ -27,6 +27,8 @@ def test_blendsearch_tune(smoke_test=True):
except ImportError:
print('ray[tune] is not installed, skipping test')
return
import numpy as np
algo = BlendSearch()
algo = ConcurrencyLimiter(algo, max_concurrent=4)
scheduler = AsyncHyperBandScheduler()
@ -42,7 +44,8 @@ def test_blendsearch_tune(smoke_test=True):
"width": tune.uniform(0, 20),
"height": tune.uniform(-100, 100),
# This is an ignored parameter.
"activation": tune.choice(["relu", "tanh"])
"activation": tune.choice(["relu", "tanh"]),
"test4": np.zeros((3, 1)),
})
print("Best hyperparameters found were: ", analysis.best_config)


@ -63,6 +63,7 @@ def _test_xgboost(method='BlendSearch'):
time_budget_s = 60
for n_cpu in [4]:
start_time = time.time()
ray.shutdown()
ray.init(num_cpus=n_cpu, num_gpus=0)
# ray.init(address='auto')
if method == 'BlendSearch':