Make NLP tasks available from AutoML.fit() (#210)

Sequence classification and regression: "seq-classification" and "seq-regression"

Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Xueqing Liu 2021-11-16 14:06:20 -05:00 committed by GitHub
parent 59083fbdcb
commit 42de3075e9
31 changed files with 2610 additions and 429 deletions

.gitignore

@@ -1,4 +1,4 @@
# Project
# Project
/.vs
.vscode


@@ -158,6 +158,36 @@ automl.fit(
)
```
* Fine-tuning a language model
```python
from flaml import AutoML
from datasets import load_dataset
train_dataset = load_dataset("glue", "mrpc", split="train").to_pandas()
dev_dataset = load_dataset("glue", "mrpc", split="validation").to_pandas()
test_dataset = load_dataset("glue", "mrpc", split="test").to_pandas()
custom_sent_keys = ["sentence1", "sentence2"]
label_key = "label"
X_train, y_train = train_dataset[custom_sent_keys], train_dataset[label_key]
X_val, y_val = dev_dataset[custom_sent_keys], dev_dataset[label_key]
X_test = test_dataset[custom_sent_keys]
automl = AutoML()
automl_settings = {
"max_iter": 3,
"time_budget": 100,
"model_history": True,
"task": "seq-classification"
}
automl_settings["custom_hpo_args"] = {
"output_dir": "data/output/",
}
automl.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings)
```
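For regression over text, the new "seq-regression" task works the same way; a minimal sketch, assuming the GLUE STS-B dataset (the column and label names below come from that dataset, and the settings are illustrative):
```python
from flaml import AutoML
from datasets import load_dataset

train_dataset = load_dataset("glue", "stsb", split="train").to_pandas()
dev_dataset = load_dataset("glue", "stsb", split="validation").to_pandas()
X_train, y_train = train_dataset[["sentence1", "sentence2"]], train_dataset["label"]
X_val, y_val = dev_dataset[["sentence1", "sentence2"]], dev_dataset["label"]
automl = AutoML()
automl_settings = {
    "max_iter": 3,
    "time_budget": 100,
    "task": "seq-regression",  # the other NLP task added in this commit
    "metric": "rmse",
    "custom_hpo_args": {"output_dir": "data/output/"},
}
automl.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings)
```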
More examples can be found in [notebooks](https://github.com/microsoft/FLAML/tree/main/notebook/).
## Documentation


@@ -44,6 +44,11 @@ Online AutoML
.. autoclass:: flaml.AutoVW
:members:
NLP
---
.. autoclass:: flaml.nlp.HPOArgs
:members:
.. Indices and tables
.. ==================


@@ -18,6 +18,9 @@ from sklearn.model_selection import (
from sklearn.utils import shuffle
import pandas as pd
import logging
from typing import List, Union
from pandas import DataFrame
from .nlp.utils import _is_nlp_task
from .ml import (
compute_estimator,
@@ -35,7 +38,8 @@ from .config import (
N_SPLITS,
SAMPLE_MULTIPLY_FACTOR,
)
from .data import concat, CLASSIFICATION, TS_FORECAST, FORECAST
from .data import concat, CLASSIFICATION, TS_FORECAST, FORECAST, REGRESSION
from . import tune
from .training_log import training_log_reader, training_log_writer
@@ -106,6 +110,7 @@ class SearchState:
self.trial_time = 0
def update(self, result, time_used, save_model_history=False):
if result:
config = result["config"]
if config and "FLAML_sample_size" in config:
@@ -117,9 +122,14 @@
time2eval = result["time_total_s"]
trained_estimator = result["trained_estimator"]
del result["trained_estimator"] # free up RAM
n_iter = trained_estimator and trained_estimator.params.get("n_estimators")
if n_iter is not None and "n_estimators" in config:
config["n_estimators"] = n_iter
n_iter = (
trained_estimator
and hasattr(trained_estimator, "ITER_HP")
and trained_estimator.params[trained_estimator.ITER_HP]
)
if n_iter:
config[trained_estimator.ITER_HP] = n_iter
else:
obj, time2eval, trained_estimator = np.inf, 0.0, None
metric_for_logging = config = None
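The `ITER_HP` lookup above generalizes the previously hard-coded `"n_estimators"`: each estimator class names its own iteration hyperparameter (`TransformersEstimator` below uses `"final_global_step"`). A minimal sketch of the contract, with a hypothetical estimator:
```python
class MyEstimator:
    ITER_HP = "n_estimators"  # names the hyperparameter that counts iterations

    def __init__(self):
        # after training, params records how many iterations were actually used
        self.params = {"n_estimators": 87}

trained_estimator = MyEstimator()
config = {"n_estimators": 500, "learning_rate": 0.1}
n_iter = (
    trained_estimator
    and hasattr(trained_estimator, "ITER_HP")
    and trained_estimator.params[trained_estimator.ITER_HP]
)
if n_iter:
    config[trained_estimator.ITER_HP] = n_iter  # config now records 87, not 500
```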
@@ -209,13 +219,21 @@ class AutoMLState:
config = config_w_resource.copy()
if "FLAML_sample_size" in config:
del config["FLAML_sample_size"]
time_left = self.time_budget - self.time_from_start
budget = (
time_left
None
if self.time_budget is None
else self.time_budget - self.time_from_start
if sample_size == self.data_size
else time_left / 2 * sample_size / self.data_size
else (self.time_budget - self.time_from_start)
/ 2
* sample_size
/ self.data_size
)
if _is_nlp_task(self.task):
self.fit_kwargs["X_val"] = self.X_val
self.fit_kwargs["y_val"] = self.y_val
(
trained_estimator,
val_loss,
@@ -229,7 +247,9 @@ class AutoMLState:
self.y_val,
self.weight_val,
self.groups_val,
min(budget, self.train_time_limit),
self.train_time_limit
if budget is None
else min(budget, self.train_time_limit),
self.kf,
config,
self.task,
@@ -242,6 +262,11 @@
self.log_training_metric,
self.fit_kwargs,
)
if _is_nlp_task(self.task):
del self.fit_kwargs["X_val"]
del self.fit_kwargs["y_val"]
result = {
"pred_time": pred_time,
"wall_clock_time": time.time() - self._start_time_flag,
@@ -253,7 +278,12 @@
self.fit_kwargs["sample_weight"] = weight
return result
def _train_with_config(self, estimator, config_w_resource, sample_size=None):
def _train_with_config(
self,
estimator,
config_w_resource,
sample_size=None,
):
if not sample_size:
sample_size = config_w_resource.get(
"FLAML_sample_size", len(self.y_train_all)
@@ -281,17 +311,52 @@ class AutoMLState:
if self.time_budget is None
else self.time_budget - self.time_from_start
)
estimator, train_time = train_estimator(
sampled_X_train,
sampled_y_train,
config,
self.task,
estimator,
self.n_jobs,
self.learner_classes.get(estimator),
budget,
self.fit_kwargs,
)
if self.resources_per_trial.get("gpu", 0) > 0:
def _trainable_function_wrapper(config: dict):
return_estimator, train_time = train_estimator(
X_train=sampled_X_train,
y_train=sampled_y_train,
config_dic=config,
task=self.task,
estimator_name=estimator,
n_jobs=self.n_jobs,
estimator_class=self.learner_classes.get(estimator),
budget=budget,
fit_kwargs=self.fit_kwargs,
)
return {"estimator": return_estimator, "train_time": train_time}
if estimator not in self.learner_classes:
self.learner_classes[estimator] = get_estimator_class(
self.task, estimator
)
analysis = tune.run(
_trainable_function_wrapper,
config=config_w_resource,
metric="train_time",
mode="min",
resources_per_trial=self.resources_per_trial,
num_samples=1,
use_ray=True,
)
result = list(analysis.results.values())[0]
estimator, train_time = result["estimator"], result["train_time"]
else:
estimator, train_time = train_estimator(
X_train=sampled_X_train,
y_train=sampled_y_train,
config_dic=config,
task=self.task,
estimator_name=estimator,
n_jobs=self.n_jobs,
estimator_class=self.learner_classes.get(estimator),
budget=budget,
fit_kwargs=self.fit_kwargs,
)
if sampled_weight is not None:
self.fit_kwargs["sample_weight"] = weight
return estimator, train_time
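When `gpu_per_trial > 0`, the training call above is routed through `flaml.tune.run` with `use_ray=True` so that Ray allocates GPUs to the trial. A stripped-down sketch of the same pattern (the trainable and resource numbers are illustrative):
```python
from flaml import tune

def _trainable(config):
    # train a model with `config` on the resources Ray assigned to this trial
    return {"train_time": 0.0}  # placeholder result

analysis = tune.run(
    _trainable,
    config={"learning_rate": 1e-5},
    metric="train_time",
    mode="min",
    resources_per_trial={"cpu": 1, "gpu": 1},  # illustrative values
    num_samples=1,
    use_ray=True,  # required for resources_per_trial to take effect
)
result = list(analysis.results.values())[0]
```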
@@ -423,7 +488,7 @@ class AutoML:
"""Time taken to find best model in seconds."""
return self.__dict__.get("_time_taken_best_iter")
def predict(self, X_test):
def predict(self, X_test: Union[np.array, DataFrame, List[str], List[List[str]]]):
"""Predict label from features.
Args:
@@ -455,6 +520,32 @@
"No estimator is trained. Please run fit with enough budget."
)
return None
if isinstance(X_test, List) and isinstance(X_test[0], List):
unzipped_X_test = [x for x in zip(*X_test)]
try:
X_test = DataFrame(
{
self._transformer._str_columns[idx]: unzipped_X_test[idx]
for idx in range(len(unzipped_X_test))
}
)
except IndexError:
raise IndexError(
"Test data contains more columns than training data, exiting"
)
elif isinstance(X_test, List):
try:
X_test = DataFrame(
{
self._transformer._str_columns[idx]: [X_test[idx]]
for idx in range(len(X_test))
}
)
except IndexError:
raise IndexError(
"Test data contains more columns than training data, exiting"
)
X_test = self._preprocess(X_test)
y_pred = estimator.predict(X_test)
if y_pred.ndim > 1 and isinstance(y_pred, np.ndarray):
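With the conversion above, `predict` accepts raw text for NLP tasks in addition to a DataFrame; a hedged usage sketch (the sentences are placeholders, and the model is assumed fitted on two text columns):
```python
# after automl.fit(...) on a "seq-classification" task:
automl.predict(["first sentence", "second sentence"])  # a single example
automl.predict([["s1a", "s1b"], ["s2a", "s2b"]])  # a batch of two examples
```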
@@ -482,6 +573,7 @@ class AutoML:
return proba
def _preprocess(self, X):
if isinstance(X, int):
return X
if self._state.task == TS_FORECAST:
@@ -503,6 +595,7 @@
groups_val=None,
groups=None,
):
if X_train_all is not None and y_train_all is not None:
assert (
isinstance(X_train_all, np.ndarray)
@@ -548,6 +641,33 @@
y = dataframe[label]
else:
raise ValueError("either X_train+y_train or dataframe+label are required")
# check the validity of input dimensions under the nlp mode
if _is_nlp_task(self._state.task):
is_all_str = True
is_all_list = True
for column in X.columns:
assert X[column].dtype.name in (
"object",
"string",
), "If the task is an NLP task, X can only contain text columns"
for each_cell in X[column]:
if each_cell:
is_str = isinstance(each_cell, str)
is_list_of_int = isinstance(each_cell, list) and all(
isinstance(x, int) for x in each_cell
)
assert is_str or is_list_of_int, (
"Each column of the input must either be str (untokenized) "
"or a list of integers (tokenized)"
)
is_all_str &= is_str
is_all_list &= is_list_of_int
assert is_all_str or is_all_list, (
"Currently FLAML only supports two modes for NLP: either all columns of X are string (non-tokenized), "
"or all columns of X are integer ids (tokenized)"
)
if issparse(X_train_all):
self._transformer = self._label_transformer = False
self._X_train_all, self._y_train_all = X, y
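The validation above admits exactly two input shapes for NLP tasks; illustrative examples of each (values are hypothetical):
```python
import pandas as pd

# mode 1: every column holds raw text
X_untokenized = pd.DataFrame({"sentence1": ["a cat sat"], "sentence2": ["a dog ran"]})

# mode 2: every column holds lists of integer ids (already tokenized)
X_tokenized = pd.DataFrame({"input_ids": [[101, 2023, 102]], "attention_mask": [[1, 1, 1]]})
```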
@@ -607,6 +727,7 @@ class AutoML:
self._state.groups = groups
def _prepare_data(self, eval_method, split_ratio, n_splits):
X_val, y_val = self._state.X_val, self._state.y_val
if issparse(X_val):
X_val = X_val.tocsr()
@@ -776,7 +897,7 @@
if self._df
else np.concatenate([label_set, y_val])
)
elif self._state.task == "regression":
elif self._state.task in REGRESSION:
if "sample_weight" in self._state.fit_kwargs:
(
X_train,
@@ -877,11 +998,11 @@
config = record.config
estimator, _ = train_estimator(
None,
None,
config,
task,
estimator,
X_train=None,
y_train=None,
config_dic=config,
task=task,
estimator_name=estimator,
estimator_class=self._state.learner_classes.get(estimator),
)
return estimator
@@ -901,6 +1022,7 @@
split_type=None,
groups=None,
n_jobs=-1,
gpu_per_trial=0,
train_best=True,
train_full=False,
record_id=-1,
@@ -946,6 +1068,7 @@
for training data.
n_jobs: An integer of the number of threads for training. Use all
available resources when n_jobs == -1.
gpu_per_trial: A float of the number of gpus per trial. Only used by TransformersEstimator.
train_best: A boolean of whether to train the best config in the
time budget; if false, train the last config in the budget.
train_full: A boolean of whether to train on the full data. If true,
@@ -963,6 +1086,7 @@
self._state.task = TS_FORECAST
else:
self._state.task = task
self._state.fit_kwargs = fit_kwargs
self._validate_data(X_train, y_train, dataframe, label, groups=groups)
@@ -1029,8 +1153,17 @@
self._prepare_data(eval_method, split_ratio, n_splits)
self._state.time_budget = None
self._state.n_jobs = n_jobs
import os
self._state.resources_per_trial = (
{"cpu": os.cpu_count(), "gpu": gpu_per_trial}
if self._state.n_jobs < 0
else {"cpu": self._state.n_jobs, "gpu": gpu_per_trial}
)
self._trained_estimator = self._state._train_with_config(
best_estimator, best_config, sample_size
best_estimator,
best_config,
sample_size=sample_size,
)[0]
logger.info("retrain from log succeeded")
return training_duration
@@ -1045,7 +1178,7 @@
self._split_type = (
split_type or self._state.groups is None and "stratified" or "group"
)
elif self._state.task == "regression":
elif self._state.task in REGRESSION:
assert split_type in [None, "uniform", "time", "group"]
self._split_type = split_type or "uniform"
elif self._state.task == TS_FORECAST:
@@ -1229,6 +1362,7 @@
mem_res = self._mem_thres
def train(config: dict):
sample_size = config.get("FLAML_sample_size")
config = config.get("ml", config).copy()
if sample_size:
@@ -1271,6 +1405,7 @@
metric="auto",
task="classification",
n_jobs=-1,
gpu_per_trial=0,
log_file_name="flaml.log",
estimator_list="auto",
time_budget=60,
@@ -1345,6 +1480,7 @@
task: A string of the task type, e.g.,
'classification', 'regression', 'ts_forecast', 'rank'.
n_jobs: An integer of the number of threads for training.
gpu_per_trial: A float of the number of gpus per trial, only used by TransformersEstimator.
log_file_name: A string of the log file name.
estimator_list: A list of strings for estimator names, or 'auto'
e.g.,
@@ -1454,12 +1590,14 @@
the searched learners, such as sample_weight. Include period as
a key word argument for 'ts_forecast' task.
"""
self._state._start_time_flag = self._start_time_flag = time.time()
if task == FORECAST:
self._state.task = TS_FORECAST
else:
self._state.task = task
self._state.log_training_metric = log_training_metric
self._state.fit_kwargs = fit_kwargs
self._state.weight_val = sample_weight_val
@@ -1502,6 +1640,10 @@
self._auto_augment = auto_augment
self._min_sample_size = min_sample_size
self._prepare_data(eval_method, split_ratio, n_splits)
if _is_nlp_task(self._state.task):
self._state.fit_kwargs["metric"] = metric
self._sample = (
sample
and task != "rank"
@@ -1549,6 +1691,8 @@
estimator_list = ["arima", "sarimax"]
elif self._state.task == "rank":
estimator_list = ["lgbm", "xgboost"]
elif _is_nlp_task(self._state.task):
estimator_list = ["transformer"]
else:
try:
import catboost
@@ -1587,6 +1731,13 @@
self.split_ratio = split_ratio
self._state.save_model_history = model_history
self._state.n_jobs = n_jobs
import os
self._state.resources_per_trial = (
{"cpu": int(os.cpu_count() / n_concurrent_trials), "gpu": gpu_per_trial}
if self._state.n_jobs < 0
else {"cpu": self._state.n_jobs, "gpu": gpu_per_trial}
)
self._n_concurrent_trials = n_concurrent_trials
self._early_stop = early_stop
self._use_ray = use_ray or n_concurrent_trials > 1
@@ -1701,9 +1852,7 @@
time_budget_s=time_left,
)
search_alg = ConcurrencyLimiter(search_alg, self._n_concurrent_trials)
resources_per_trial = (
{"cpu": self._state.n_jobs} if self._state.n_jobs > 1 else None
)
resources_per_trial = self._state.resources_per_trial
analysis = ray.tune.run(
self.trainable,
search_alg=search_alg,
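Taken together, `n_jobs`, `n_concurrent_trials`, and the new `gpu_per_trial` argument determine the `resources_per_trial` handed to Ray; a hedged usage sketch (data and metric as in the examples above):
```python
automl = AutoML()
automl.fit(
    X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val,
    task="seq-classification",
    metric="accuracy",
    gpu_per_trial=1,  # give each trial one GPU; 0 keeps training on CPU
    n_concurrent_trials=2,  # available CPUs are split across concurrent trials
    custom_hpo_args={"output_dir": "data/output/"},
)
```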


@@ -5,12 +5,17 @@
import numpy as np
from scipy.sparse import vstack, issparse
import pandas as pd
from pandas import DataFrame, Series
from .training_log import training_log_reader
from datetime import datetime
from typing import Dict, Union, List
CLASSIFICATION = ("binary", "multi", "classification")
SEQCLASSIFICATION = "seq-classification"
CLASSIFICATION = ("binary", "multi", "classification", SEQCLASSIFICATION)
SEQREGRESSION = "seq-regression"
REGRESSION = ("regression", SEQREGRESSION)
TS_FORECAST = "ts_forecast"
TS_TIMESTAMP_COL = "ds"
TS_VALUE_COL = "y"
@@ -190,10 +195,10 @@ def get_output_from_log(filename, time_budget):
def concat(X1, X2):
"""concatenate two matrices vertically"""
if isinstance(X1, pd.DataFrame) or isinstance(X1, pd.Series):
if isinstance(X1, DataFrame) or isinstance(X1, Series):
df = pd.concat([X1, X2], sort=False)
df.reset_index(drop=True, inplace=True)
if isinstance(X1, pd.DataFrame):
if isinstance(X1, DataFrame):
cat_columns = X1.select_dtypes(include="category").columns
if len(cat_columns):
df[cat_columns] = df[cat_columns].astype("category")
@@ -207,7 +212,7 @@ def concat(X1, X2):
class DataTransformer:
"""Transform input training data."""
def fit_transform(self, X, y, task):
def fit_transform(self, X: Union[DataFrame, np.array], y, task):
"""Fit transformer and process the input training data according to the task type.
Args:
@@ -220,7 +225,19 @@ class DataTransformer:
X: Processed numpy array or pandas dataframe of training data.
y: Processed numpy array or pandas series of labels.
"""
if isinstance(X, pd.DataFrame):
from .nlp.utils import _is_nlp_task
if _is_nlp_task(task):
# if the mode is NLP, check the type of input, each column must be either string or
# ids (input ids, token type id, attention mask, etc.)
str_columns = []
for column in X.columns:
if isinstance(X[column].iloc[0], str):
str_columns.append(column)
if len(str_columns) > 0:
X[str_columns] = X[str_columns].astype("string")
self._str_columns = str_columns
elif isinstance(X, DataFrame):
X = X.copy()
n = X.shape[0]
cat_columns, num_columns, datetime_columns = [], [], []
@@ -228,7 +245,7 @@
if task == TS_FORECAST:
X = X.rename(columns={X.columns[0]: TS_TIMESTAMP_COL})
ds_col = X.pop(TS_TIMESTAMP_COL)
if isinstance(y, pd.Series):
if isinstance(y, Series):
y = y.rename(TS_VALUE_COL)
for column in X.columns:
# sklearn\utils\validation.py needs int/float values
@@ -332,7 +349,7 @@ class DataTransformer:
self._task = task
return X, y
def transform(self, X):
def transform(self, X: Union[DataFrame, np.array]):
"""Process data using fit transformer.
Args:
@@ -346,7 +363,15 @@
y: Processed numpy array or pandas series of labels.
"""
X = X.copy()
if isinstance(X, pd.DataFrame):
from .nlp.utils import _is_nlp_task
if _is_nlp_task(self._task):
# if the mode is NLP, check the type of input, each column must be either string or
# ids (input ids, token type id, attention mask, etc.)
if len(self._str_columns) > 0:
X[self._str_columns] = X[self._str_columns].astype("string")
elif isinstance(X, DataFrame):
cat_columns, num_columns, datetime_columns = (
self._cat_columns,
self._num_columns,


@@ -30,6 +30,7 @@ from .model import (
Prophet,
ARIMA,
SARIMAX,
TransformersEstimator,
)
from .data import CLASSIFICATION, group_counts, TS_FORECAST, TS_VALUE_COL
import logging
@@ -61,6 +62,8 @@ def get_estimator_class(task, estimator_name):
estimator_class = ARIMA
elif estimator_name == "sarimax":
estimator_class = SARIMAX
elif estimator_name == "transformer":
estimator_class = TransformersEstimator
else:
raise ValueError(
estimator_name + " is not a built-in learner. "
@@ -415,7 +418,11 @@ def compute_estimator(
fit_kwargs={},
):
estimator_class = estimator_class or get_estimator_class(task, estimator_name)
estimator = estimator_class(**config_dic, task=task, n_jobs=n_jobs)
estimator = estimator_class(
**config_dic,
task=task,
n_jobs=n_jobs,
)
if "holdout" == eval_method:
val_loss, metric_for_logging, train_time, pred_time = get_test_loss(
config_dic,
@@ -450,9 +457,9 @@
def train_estimator(
config_dic,
X_train,
y_train,
config_dic,
task,
estimator_name,
n_jobs=1,
@@ -462,7 +469,11 @@
):
start_time = time.time()
estimator_class = estimator_class or get_estimator_class(task, estimator_name)
estimator = estimator_class(**config_dic, task=task, n_jobs=n_jobs)
estimator = estimator_class(
**config_dic,
task=task,
n_jobs=n_jobs,
)
if X_train is not None:
train_time = estimator.fit(X_train, y_train, budget, **fit_kwargs)
else:


@@ -14,7 +14,6 @@ from sklearn.ensemble import ExtraTreesRegressor, ExtraTreesClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.dummy import DummyClassifier, DummyRegressor
from scipy.sparse import issparse
import pandas as pd
import logging
from . import tune
from .data import (
@@ -25,6 +24,9 @@ from .data import (
TS_VALUE_COL,
)
import pandas as pd
from pandas import DataFrame, Series
try:
import psutil
except ImportError:
@@ -81,8 +83,8 @@ class BaseEstimator:
Args:
task: A string of the task type, one of
'binary', 'multi', 'regression', 'rank', 'forecast'
config: A dictionary containing the hyperparameter names
and 'n_jobs' as keys. n_jobs is the number of parallel threads.
config: A dictionary containing the hyperparameter names and 'n_jobs' as keys.
n_jobs is the number of parallel threads.
"""
self.params = self.config2params(config)
self.estimator_class = self._model = None
@@ -273,7 +275,270 @@ class BaseEstimator:
Returns:
A dict that will be passed to self.estimator_class's constructor.
"""
return config.copy()
params = config.copy()
return params
class TransformersEstimator(BaseEstimator):
"""The class for fine-tuning language models, using huggingface transformers API."""
ITER_HP = "final_global_step"
def __init__(self, task="seq-classification", **config):
super().__init__(task, **config)
def _join(self, X_train, y_train):
y_train = DataFrame(y_train, columns=["label"], index=X_train.index)
train_df = X_train.join(y_train)
return train_df
@classmethod
def search_space(cls, **params):
import sys
return {
"learning_rate": {
"domain": tune.loguniform(lower=1e-6, upper=1e-3),
},
"num_train_epochs": {
"domain": tune.loguniform(lower=0.5, upper=10.0),
},
"per_device_train_batch_size": {
"domain": tune.choice([4, 8, 16, 32]),
},
"warmup_ratio": {
"domain": tune.uniform(lower=0.0, upper=0.3),
},
"weight_decay": {
"domain": tune.uniform(lower=0.0, upper=0.3),
},
"adam_epsilon": {
"domain": tune.loguniform(lower=1e-8, upper=1e-6),
},
"seed": {"domain": tune.choice(list(range(40, 45)))},
"final_global_step": {"domain": sys.maxsize},
}
def _init_hpo_args(self, automl_fit_kwargs: dict = None):
from .nlp.utils import HPOArgs
custom_hpo_args = HPOArgs()
for key, val in automl_fit_kwargs["custom_hpo_args"].items():
assert (
key in custom_hpo_args.__dict__
), "The specified key {} is not in the argument list of flaml.nlp.utils::HPOArgs".format(
key
)
setattr(custom_hpo_args, key, val)
self.custom_hpo_args = custom_hpo_args
def _preprocess(self, X, task, **kwargs):
from .nlp.utils import tokenize_text
if X.dtypes[0] == "string":
return tokenize_text(X, task, self.custom_hpo_args)
else:
return X
def fit(self, X_train: DataFrame, y_train: Series, budget=None, **kwargs):
# TODO: when self.params = {}, i.e., max_iter = 1, fix the bug
from transformers import EarlyStoppingCallback
this_params = self.params
class EarlyStoppingCallbackForAuto(EarlyStoppingCallback):
def on_train_begin(self, args, state, control, **callback_kwargs):
self.train_begin_time = time.time()
def on_step_begin(self, args, state, control, **callback_kwargs):
self.step_begin_time = time.time()
def on_step_end(self, args, state, control, **callback_kwargs):
if state.global_step == 1:
self.time_per_iter = time.time() - self.step_begin_time
if budget:
if (
time.time() + self.time_per_iter
> self.train_begin_time + budget
):
control.should_training_stop = True
control.should_save = True
control.should_evaluate = True
if state.global_step >= this_params[TransformersEstimator.ITER_HP]:
control.should_training_stop = True
return control
import transformers
from transformers import TrainingArguments
from transformers.trainer_utils import set_seed
from transformers import AutoTokenizer
from .nlp.utils import (
separate_config,
load_model,
get_num_labels,
compute_checkpoint_freq,
)
from .nlp.huggingface.trainer import TrainerForAuto
from datasets import Dataset
self._init_hpo_args(kwargs)
self._metric_name = kwargs["metric"]
X_val = kwargs.get("X_val")
y_val = kwargs.get("y_val")
X_train = self._preprocess(X_train, self._task, **kwargs)
train_dataset = Dataset.from_pandas(self._join(X_train, y_train))
if X_val is not None:
X_val = self._preprocess(X_val, self._task, **kwargs)
eval_dataset = Dataset.from_pandas(self._join(X_val, y_val))
else:
eval_dataset = None
tokenizer = AutoTokenizer.from_pretrained(
self.custom_hpo_args.model_path, use_fast=True
)
set_seed(self.params["seed"])
num_labels = get_num_labels(self._task, y_train)
training_args_config, per_model_config = separate_config(self.params)
this_model = load_model(
checkpoint_path=self.custom_hpo_args.model_path,
task=self._task,
num_labels=num_labels,
per_model_config=per_model_config,
)
ckpt_freq = compute_checkpoint_freq(
train_data_size=len(X_train),
custom_hpo_args=self.custom_hpo_args,
num_train_epochs=self.params["num_train_epochs"],
batch_size=self.params["per_device_train_batch_size"],
)
if transformers.__version__.startswith("3"):
training_args = TrainingArguments(
output_dir=self.custom_hpo_args.output_dir,
do_train=True,
do_eval=True,
eval_steps=ckpt_freq,
evaluate_during_training=True,
save_steps=ckpt_freq,
save_total_limit=0,
fp16=self.custom_hpo_args.fp16,
load_best_model_at_end=True,
**training_args_config,
)
else:
from transformers import IntervalStrategy
training_args = TrainingArguments(
output_dir=self.custom_hpo_args.output_dir,
do_train=True,
do_eval=True,
per_device_eval_batch_size=1,
eval_steps=ckpt_freq,
evaluation_strategy=IntervalStrategy.STEPS,
save_steps=ckpt_freq,
save_total_limit=0,
fp16=self.custom_hpo_args.fp16,
load_best_model_at_end=True,
**training_args_config,
)
def _model_init():
return load_model(
checkpoint_path=self.custom_hpo_args.model_path,
task=self._task,
num_labels=num_labels,
per_model_config=per_model_config,
)
trainer = TrainerForAuto(
model=this_model,
args=training_args,
model_init=_model_init,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
compute_metrics=self._compute_metrics_by_dataset_name,
callbacks=[EarlyStoppingCallbackForAuto],
)
trainer.train()
if eval_dataset is not None:
# if validation data is non-empty, select the best checkpoint and save the final global step to self.params
self.params[self.ITER_HP] = trainer.state.global_step
if trainer.state.global_step > max(trainer.ckpt_to_global_step.values()):
trainer.evaluate()
self._checkpoint_path = self._select_checkpoint(
trainer.ckpt_to_metric, trainer.ckpt_to_global_step
)
else:
# if validation dataset is empty, save the last checkpoint
self._checkpoint_path = self._save_last_checkpoint(trainer)
self._kwargs = kwargs
self._num_labels = num_labels
self._per_model_config = per_model_config
def _save_last_checkpoint(self, trainer):
this_ckpt = trainer.save_state()
self.params[self.ITER_HP] = trainer.state.global_step
return this_ckpt
def _select_checkpoint(self, ckpt_to_score, ckpt_to_global_step):
best_ckpt, best_score = min(
ckpt_to_score.items(), key=lambda x: x[1][self._metric_name]
)
best_ckpt_global_step = ckpt_to_global_step[best_ckpt]
self.params[self.ITER_HP] = best_ckpt_global_step
return best_ckpt
def _compute_metrics_by_dataset_name(self, eval_pred):
from .ml import sklearn_metric_loss_score
from .data import SEQREGRESSION
predictions, labels = eval_pred
predictions = (
np.squeeze(predictions)
if self._task == SEQREGRESSION
else np.argmax(predictions, axis=1)
)
return {
self._metric_name: sklearn_metric_loss_score(
metric_name=self._metric_name, y_predict=predictions, y_true=labels
)
}
def predict(self, X_test):
from datasets import Dataset
from .nlp.utils import load_model
from transformers import TrainingArguments
from .nlp.huggingface.trainer import TrainerForAuto
if X_test.dtypes[0] == "string":
X_test = self._preprocess(X_test, self._task, **self._kwargs)
test_dataset = Dataset.from_pandas(X_test)
best_model = load_model(
checkpoint_path=self._checkpoint_path,
task=self._task,
num_labels=self._num_labels,
per_model_config=self._per_model_config,
)
training_args = TrainingArguments(
per_device_eval_batch_size=1,
output_dir=self.custom_hpo_args.output_dir,
)
test_trainer = TrainerForAuto(model=best_model, args=training_args)
predictions = test_trainer.predict(test_dataset)
return np.argmax(predictions.predictions, axis=1)
class SKLearnEstimator(BaseEstimator):
@@ -283,14 +548,14 @@ class SKLearnEstimator(BaseEstimator):
super().__init__(task, **config)
def _preprocess(self, X):
if isinstance(X, pd.DataFrame):
if isinstance(X, DataFrame):
cat_columns = X.select_dtypes(include=["category"]).columns
if not cat_columns.empty:
X = X.copy()
X[cat_columns] = X[cat_columns].apply(lambda x: x.cat.codes)
elif isinstance(X, np.ndarray) and X.dtype.kind not in "buif":
# numpy array is not of numeric dtype
X = pd.DataFrame(X)
X = DataFrame(X)
for col in X.columns:
if isinstance(X[col][0], str):
X[col] = X[col].astype("category").cat.codes
@@ -383,14 +648,14 @@ class LGBMEstimator(BaseEstimator):
def _preprocess(self, X):
if (
not isinstance(X, pd.DataFrame)
not isinstance(X, DataFrame)
and issparse(X)
and np.issubdtype(X.dtype, np.integer)
):
X = X.astype(float)
elif isinstance(X, np.ndarray) and X.dtype.kind not in "buif":
# numpy array is not of numeric dtype
X = pd.DataFrame(X)
X = DataFrame(X)
for col in X.columns:
if isinstance(X[col][0], str):
X[col] = X[col].astype("category").cat.codes
@@ -665,6 +930,7 @@ class XGBoostSklearnEstimator(SKLearnEstimator, LGBMEstimator):
return XGBoostEstimator.cost_relative2lgbm()
def config2params(cls, config: dict) -> dict:
# TODO: test
params = config.copy()
params["max_depth"] = 0
params["grow_policy"] = params.get("grow_policy", "lossguide")
@@ -859,7 +1125,7 @@ class CatBoostEstimator(BaseEstimator):
return 15
def _preprocess(self, X):
if isinstance(X, pd.DataFrame):
if isinstance(X, DataFrame):
cat_columns = X.select_dtypes(include=["category"]).columns
if not cat_columns.empty:
X = X.copy()
@@ -873,7 +1139,7 @@
)
elif isinstance(X, np.ndarray) and X.dtype.kind not in "buif":
# numpy array is not of numeric dtype
X = pd.DataFrame(X)
X = DataFrame(X)
for col in X.columns:
if isinstance(X[col][0], str):
X[col] = X[col].astype("category").cat.codes
@@ -914,7 +1180,7 @@
deadline = start_time + budget if budget else np.inf
train_dir = f"catboost_{str(start_time)}"
X_train = self._preprocess(X_train)
if isinstance(X_train, pd.DataFrame):
if isinstance(X_train, DataFrame):
cat_features = list(X_train.select_dtypes(include="category").columns)
else:
cat_features = []
@@ -1009,14 +1275,14 @@ class KNeighborsEstimator(BaseEstimator):
self.estimator_class = KNeighborsRegressor
def _preprocess(self, X):
if isinstance(X, pd.DataFrame):
if isinstance(X, DataFrame):
cat_columns = X.select_dtypes(["category"]).columns
if X.shape[1] == len(cat_columns):
raise ValueError("kneighbor requires at least one numeric feature")
X = X.drop(cat_columns, axis=1)
elif isinstance(X, np.ndarray) and X.dtype.kind not in "buif":
# drop categorical columns if any
X = pd.DataFrame(X)
X = DataFrame(X)
cat_columns = []
for col in X.columns:
if isinstance(X[col][0], str):
@@ -1060,7 +1326,7 @@ class Prophet(SKLearnEstimator):
"Dataframe for training ts_forecast model must have column"
f' "{TS_TIMESTAMP_COL}" with the dates in X_train.'
)
y_train = pd.DataFrame(y_train, columns=[TS_VALUE_COL])
y_train = DataFrame(y_train, columns=[TS_VALUE_COL])
train_df = X_train.join(y_train)
return train_df
@@ -1167,7 +1433,7 @@ class ARIMA(Prophet):
if self._model is not None:
if isinstance(X_test, int):
forecast = self._model.forecast(steps=X_test)
elif isinstance(X_test, pd.DataFrame):
elif isinstance(X_test, DataFrame):
first_col = X_test.pop(TS_TIMESTAMP_COL)
X_test.insert(0, TS_TIMESTAMP_COL, first_col)
start = X_test.iloc[0, 0]
@@ -1183,7 +1449,7 @@
forecast = self._model.predict(start=start, end=end)
else:
raise ValueError(
"X_test needs to be either a pd.Dataframe with dates as the first column"
"X_test needs to be either a pandas Dataframe with dates as the first column"
" or an int number of periods for predict()."
)
return forecast

flaml/nlp/README.md

@@ -0,0 +1,76 @@
# Hyperparameter Optimization for Huggingface Transformers
Fine-tuning pre-trained language models based on the transformers library.
An example:
```python
from flaml import AutoML
import pandas as pd
train_dataset = pd.read_csv("data/input/train.tsv", delimiter="\t", quoting=3)
dev_dataset = pd.read_csv("data/input/dev.tsv", delimiter="\t", quoting=3)
test_dataset = pd.read_csv("data/input/test.tsv", delimiter="\t", quoting=3)
custom_sent_keys = ["#1 String", "#2 String"]
label_key = "Quality"
X_train = train_dataset[custom_sent_keys]
y_train = train_dataset[label_key]
X_val = dev_dataset[custom_sent_keys]
y_val = dev_dataset[label_key]
X_test = test_dataset[custom_sent_keys]
automl = AutoML()
automl_settings = {
"gpu_per_trial": 0, # use a value larger than 0 for GPU training
"max_iter": 10,
"time_budget": 300,
"task": "seq-classification",
"metric": "accuracy",
}
automl_settings["custom_hpo_args"] = {
"model_path": "google/electra-small-discriminator",
"output_dir": "data/output/",
"ckpt_per_epoch": 1,
}
automl.fit(
X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings
)
automl.predict(X_test)
```
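After `fit` returns, the usual `AutoML` introspection attributes apply to the transformer runs as well; for instance (a sketch):
```python
print(automl.best_config)  # best hyperparameters found for the "transformer" learner
print(automl.best_loss)  # best validation loss observed (here 1 - accuracy)
```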
Currently supported use cases:
1. A simplified version of fine-tuning on the GLUE dataset using HuggingFace;
2. Selecting a better search space for fine-tuning on the GLUE dataset;
3. Using FLAML's search algorithms for more efficient fine-tuning of HuggingFace models.
Use cases that may be supported in the future:
1. HPO for fine-tuning text generation models;
2. HPO for fine-tuning question answering models.
## Troubleshooting fine-tuning HPO for pre-trained language models
To reproduce the results for our ACL2021 paper:
* [An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://arxiv.org/abs/2106.09204). Xueqing Liu, Chi Wang. ACL-IJCNLP 2021.
```bibtex
@inproceedings{liu2021hpo,
title={An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models},
author={Xueqing Liu and Chi Wang},
year={2021},
booktitle={ACL-IJCNLP},
}
```
Please refer to the following jupyter notebook: [Troubleshooting HPO for fine-tuning pre-trained language models](https://github.com/microsoft/FLAML/blob/main/notebook/research/acl2021.ipynb)

flaml/nlp/__init__.py

@@ -0,0 +1 @@
from .utils import HPOArgs


@@ -0,0 +1,58 @@
from collections import OrderedDict
import transformers
if transformers.__version__.startswith("3"):
from transformers.modeling_electra import ElectraClassificationHead
from transformers.modeling_roberta import RobertaClassificationHead
else:
from transformers.models.electra.modeling_electra import ElectraClassificationHead
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
MODEL_CLASSIFICATION_HEAD_MAPPING = OrderedDict(
[
("electra", ElectraClassificationHead),
("roberta", RobertaClassificationHead),
]
)
class AutoSeqClassificationHead:
"""
A class for getting the classification head class based on the name of the language model,
instantiated as one of the ClassificationHead classes of the library when
created with the `~flaml.nlp.huggingface.AutoSeqClassificationHead.from_model_type_and_config` method.
This class cannot be instantiated directly using ``__init__()`` (throws an error).
"""
def __init__(self):
raise EnvironmentError(
"AutoSeqClassificationHead is designed to be instantiated "
"using the `AutoSeqClassificationHead.from_model_type_and_config(cls, model_type, config)` methods."
)
@classmethod
def from_model_type_and_config(cls, model_type, config):
"""
Instantiate one of the classification head classes from the model_type and model configuration.
Args:
model_type:
A string that describes the model type, e.g., "electra"
config (:class:`~transformers.PretrainedConfig`):
The huggingface class of the model's configuration:
Examples::
>>> from transformers import AutoConfig
>>> model_config = AutoConfig.from_pretrained("google/electra-base-discriminator")
>>> AutoSeqClassificationHead.from_model_type_and_config("electra", model_config)
"""
if model_type in MODEL_CLASSIFICATION_HEAD_MAPPING.keys():
return MODEL_CLASSIFICATION_HEAD_MAPPING[model_type](config)
raise ValueError(
"Unrecognized configuration class {} for class {}.\n"
"Model type should be one of {}.".format(
config.__class__, cls.__name__, ", ".join(MODEL_CLASSIFICATION_HEAD_MAPPING.keys())
)
)


@@ -0,0 +1,60 @@
import os
try:
from transformers import Trainer as TFTrainer
except ImportError:
TFTrainer = object
class TrainerForAuto(TFTrainer):
def evaluate(self, eval_dataset=None):
"""
Overriding transformers.Trainer.evaluate by saving state with save_state
Args:
eval_dataset:
the dataset to be evaluated
"""
if self.eval_dataset is not None:
eval_dataloader = self.get_eval_dataloader(self.eval_dataset)
output = self.prediction_loop(eval_dataloader, description="Evaluation")
self.log(output.metrics)
ckpt_dir = self.save_state()
for key in list(output.metrics.keys()):
if key.startswith("eval_"):
output.metrics[key[5:]] = output.metrics.pop(key)
if hasattr(self, "ckpt_to_global_step"):
self.ckpt_to_metric[ckpt_dir] = output.metrics
self.ckpt_to_global_step[ckpt_dir] = self.state.global_step
else:
self.ckpt_to_global_step = {ckpt_dir: self.state.global_step}
self.ckpt_to_metric = {ckpt_dir: output.metrics}
def save_state(self):
"""
Overriding transformers.Trainer.save_state. Only by saving the state
can best_trial.get_best_checkpoint return a non-empty value.
"""
import torch
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from ray import tune
with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
self.args.output_dir = checkpoint_dir
# This is the directory name that Huggingface requires.
output_dir = os.path.join(
self.args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}",
)
self.save_model(output_dir)
torch.save(
self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")
)
torch.save(
self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")
)
return output_dir
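The `ckpt_to_metric` and `ckpt_to_global_step` dictionaries populated in `evaluate` above are what `TransformersEstimator._select_checkpoint` consumes after training; roughly (a sketch, assuming "accuracy" was the metric name):
```python
# the stored metrics are loss scores, so the smallest value wins
best_ckpt, _ = min(trainer.ckpt_to_metric.items(), key=lambda x: x[1]["accuracy"])
best_step = trainer.ckpt_to_global_step[best_ckpt]  # written back to params["final_global_step"]
```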

flaml/nlp/utils.py

@@ -0,0 +1,216 @@
import argparse
from dataclasses import dataclass, field
from ..data import SEQCLASSIFICATION, SEQREGRESSION
def _is_nlp_task(task):
return task in [SEQCLASSIFICATION, SEQREGRESSION]
global tokenized_column_names
def tokenize_text(X, task, custom_hpo_task):
from ..data import SEQCLASSIFICATION
if task in (SEQCLASSIFICATION, SEQREGRESSION):
return tokenize_text_seqclassification(X, custom_hpo_task)
def tokenize_text_seqclassification(X, custom_hpo_args):
from transformers import AutoTokenizer
import pandas
global tokenized_column_names
this_tokenizer = AutoTokenizer.from_pretrained(
custom_hpo_args.model_path, use_fast=True
)
d = X.apply(
lambda x: tokenize_glue(x, this_tokenizer, custom_hpo_args),
axis=1,
result_type="expand",
)
X_tokenized = pandas.DataFrame(columns=tokenized_column_names)
X_tokenized[tokenized_column_names] = d
return X_tokenized
def tokenize_glue(this_row, this_tokenizer, custom_hpo_args):
global tokenized_column_names
assert (
"max_seq_length" in custom_hpo_args.__dict__
), "max_seq_length must be provided for glue"
tokenized_example = this_tokenizer(
*tuple(this_row),
padding="max_length",
max_length=custom_hpo_args.max_seq_length,
truncation=True,
)
tokenized_column_names = sorted(tokenized_example.keys())
return [tokenized_example[x] for x in tokenized_column_names]
def separate_config(config):
from transformers import TrainingArguments
training_args_config = {}
per_model_config = {}
for key, val in config.items():
if key in TrainingArguments.__dict__:
training_args_config[key] = val
else:
per_model_config[key] = val
return training_args_config, per_model_config
def get_num_labels(task, y_train):
if task == SEQREGRESSION:
return 1
elif task == SEQCLASSIFICATION:
return len(set(y_train))
def load_model(checkpoint_path, task, num_labels, per_model_config=None):
from transformers import AutoConfig
from .huggingface.switch_head_auto import (
AutoSeqClassificationHead,
MODEL_CLASSIFICATION_HEAD_MAPPING,
)
this_model_type = AutoConfig.from_pretrained(checkpoint_path).model_type
this_vocab_size = AutoConfig.from_pretrained(checkpoint_path).vocab_size
def get_this_model():
from transformers import AutoModelForSequenceClassification
return AutoModelForSequenceClassification.from_pretrained(
checkpoint_path, config=model_config
)
def is_pretrained_model_in_classification_head_list(model_type):
return model_type in MODEL_CLASSIFICATION_HEAD_MAPPING
def _set_model_config(checkpoint_path):
if per_model_config and len(per_model_config) > 0:
model_config = AutoConfig.from_pretrained(
checkpoint_path,
num_labels=model_config_num_labels,
**per_model_config,
)
else:
model_config = AutoConfig.from_pretrained(
checkpoint_path, num_labels=model_config_num_labels
)
return model_config
if task == SEQCLASSIFICATION:
num_labels_old = AutoConfig.from_pretrained(checkpoint_path).num_labels
if is_pretrained_model_in_classification_head_list(this_model_type):
model_config_num_labels = num_labels_old
else:
model_config_num_labels = num_labels
model_config = _set_model_config(checkpoint_path)
if is_pretrained_model_in_classification_head_list(this_model_type):
if num_labels != num_labels_old:
this_model = get_this_model()
model_config.num_labels = num_labels
this_model.num_labels = num_labels
this_model.classifier = (
AutoSeqClassificationHead.from_model_type_and_config(
this_model_type, model_config
)
)
else:
this_model = get_this_model()
else:
this_model = get_this_model()
this_model.resize_token_embeddings(this_vocab_size)
return this_model
elif task == SEQREGRESSION:
model_config_num_labels = 1
model_config = _set_model_config(checkpoint_path)
this_model = get_this_model()
return this_model
def compute_checkpoint_freq(
train_data_size,
custom_hpo_args,
num_train_epochs,
batch_size,
):
ckpt_step_freq = (
int(
min(num_train_epochs, 1)
* train_data_size
/ batch_size
/ custom_hpo_args.ckpt_per_epoch
)
+ 1
)
return ckpt_step_freq
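# Worked example (hypothetical numbers): with train_data_size=3668, batch_size=32,
# num_train_epochs=3.0 and ckpt_per_epoch=1, this yields
# int(min(3.0, 1) * 3668 / 32 / 1) + 1 = 115 steps between checkpoints, i.e. roughly
# one checkpoint per epoch; capping the epoch factor at 1 keeps the interval from
# growing on long schedules.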
@dataclass
class HPOArgs:
"""The HPO setting
Args:
output_dir (:obj:`str`):
data root directory for outputting the log, etc.
model_path (:obj:`str`, `optional`, defaults to :obj:`facebook/muppet-roberta-base`):
A string, the path of the language model file, either the name of a model hosted on
huggingface.co/models or a local path to the model
fp16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
A bool, whether to use FP16
max_seq_length (:obj:`int`, `optional`, defaults to :obj:`128`):
An integer, the max length of the sequence
ckpt_per_epoch (:obj:`int`, `optional`, defaults to :obj:`1`):
An integer, the number of checkpoints per epoch
"""
output_dir: str = field(
default="data/output/", metadata={"help": "data dir", "required": True}
)
model_path: str = field(
default="facebook/muppet-roberta-base",
metadata={"help": "model path model for HPO"},
)
fp16: bool = field(default=True, metadata={"help": "whether to use the FP16 mode"})
max_seq_length: int = field(default=128, metadata={"help": "max seq length"})
ckpt_per_epoch: int = field(default=1, metadata={"help": "checkpoint per epoch"})
@staticmethod
def load_args():
from dataclasses import fields
arg_parser = argparse.ArgumentParser()
for each_field in fields(HPOArgs):
arg_parser.add_argument(
"--" + each_field.name,
type=each_field.type,
help=each_field.metadata["help"],
required=each_field.metadata["required"]
if "required" in each_field.metadata
else False,
choices=each_field.metadata["choices"]
if "choices" in each_field.metadata
else None,
default=each_field.default,
)
console_args, unknown = arg_parser.parse_known_args()
return console_args
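`TransformersEstimator._init_hpo_args` fills these same fields from the `custom_hpo_args` dict passed to `AutoML.fit`, so every key must be an `HPOArgs` field; a sketch of a full override (values are illustrative):
```python
automl_settings["custom_hpo_args"] = {
    "output_dir": "data/output/",  # required
    "model_path": "google/electra-small-discriminator",  # HF model card or local path
    "fp16": False,
    "max_seq_length": 128,
    "ckpt_per_epoch": 1,
}
```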


@@ -0,0 +1,809 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) 2020-2021. All rights reserved.\n",
"\n",
"Licensed under the MIT License.\n",
"\n",
"# Troubleshooting HPO for fine-tuning pre-trained language models\n",
"\n",
"## 1. Introduction\n",
"\n",
"In this notebook, we demonstrate a procedure for troubleshooting HPO failure in fine-tuning pre-trained language models (introduced in the following paper):\n",
"\n",
"*[An Empirical Study on Hyperparameter Optimization for Fine-Tuning Pre-trained Language Models](https://arxiv.org/abs/2106.09204). Xueqing Liu, Chi Wang. ACL-IJCNLP 2021*\n",
"\n",
"Notes:\n",
"\n",
"*In this notebook, we only run each experiment 1 time for simplicity, which is different from the paper (3 times). To reproduce the paper's result, please run 3 repetitions and take the average scores.\n",
"\n",
"*Running this notebook takes about one hour.\n",
"\n",
"FLAML requires `Python>=3.6`. To run this notebook example, please install flaml with the `notebook` and `nlp` options:\n",
"\n",
"```bash\n",
"pip install flaml[nlp]==0.7.1 # in higher version of flaml, the API for nlp tasks changed\n",
"```\n",
"\n",
"Our paper was developed under transformers version 3.4.0. We uninstall and reinstall transformers==3.4.0:\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"!pip install flaml[nlp]==0.7.1 # in higher version of flaml, the API for nlp tasks changed\n",
"!pip install transformers==3.4.0\n",
"from flaml.nlp import AutoTransformers\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Initial Experimental Study\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load dataset \n",
"\n",
"Load the dataset using AutoTransformer.prepare_data. In this notebook, we use the Microsoft Research Paraphrasing Corpus (MRPC) dataset and the Electra model as an example:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"console_args has no attribute pretrained_model_size, continue\n",
"console_args has no attribute dataset_subdataset_name, continue\n",
"console_args has no attribute algo_mode, continue\n",
"console_args has no attribute space_mode, continue\n",
"console_args has no attribute search_alg_args_mode, continue\n",
"console_args has no attribute algo_name, continue\n",
"console_args has no attribute pruner, continue\n",
"console_args has no attribute resplit_mode, continue\n",
"console_args has no attribute rep_id, continue\n",
"console_args has no attribute seed_data, continue\n",
"console_args has no attribute seed_transformers, continue\n",
"console_args has no attribute learning_rate, continue\n",
"console_args has no attribute weight_decay, continue\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Reusing dataset glue (/home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)\n",
"Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-6a78e5c95406457c.arrow\n",
"Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-e8d0f3e04c3b4588.arrow\n",
"Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-4b0966b394994163.arrow\n",
"Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-6a78e5c95406457c.arrow\n",
"Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-e8d0f3e04c3b4588.arrow\n",
"Loading cached processed dataset at /home/xliu127/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-4b0966b394994163.arrow\n"
]
}
],
"source": [
"autohf = AutoTransformers()\n",
"preparedata_setting = {\n",
" \"dataset_subdataset_name\": \"glue:mrpc\",\n",
" \"pretrained_model_size\": \"google/electra-base-discriminator:base\",\n",
" \"data_root_path\": \"data/\",\n",
" \"max_seq_length\": 128,\n",
" }\n",
"autohf.prepare_data(**preparedata_setting)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Running grid search\n",
"\n",
"First, we run grid search using Electra. By specifying `algo_mode=\"grid\"`, AutoTransformers will run the grid search algorithm. By specifying `space_mode=\"grid\"`, AutoTransformers will use the default grid search configuration recommended by the Electra paper:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"pycharm": {
"name": "#%%\n"
},
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"== Status ==<br>Memory usage on this node: 14.2/376.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/96 CPUs, 0/4 GPUs, 0.0/250.73 GiB heap, 0.0/76.9 GiB objects (0/1.0 accelerator_type:V100)<br>Current best trial: 67d99_00002 with accuracy=0.7254901960784313 and parameters={'learning_rate': 0.0001, 'weight_decay': 0.0, 'adam_epsilon': 1e-06, 'warmup_ratio': 0.1, 'per_device_train_batch_size': 32, 'hidden_dropout_prob': 0.1, 'attention_probs_dropout_prob': 0.1, 'num_train_epochs': 0.5, 'seed': 42}<br>Result logdir: /data/xliu127/projects/hyperopt/FLAML/notebook/data/checkpoint/dat=glue_subdat=mrpc_mod=grid_spa=grid_arg=dft_alg=grid_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None/ray_result<br>Number of trials: 4/4 (4 TERMINATED)<br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-06-16 10:45:35,071\tINFO tune.py:450 -- Total run time: 106.56 seconds (106.41 seconds for the tuning loop).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total running time: 106.57789206504822 seconds\n"
]
}
],
"source": [
"import transformers\n",
"autohf_settings = {\n",
" \"resources_per_trial\": {\"gpu\": 1, \"cpu\": 1},\n",
" \"num_samples\": 1,\n",
" \"time_budget\": 100000, # unlimited time budget\n",
" \"ckpt_per_epoch\": 5,\n",
" \"fp16\": True,\n",
" \"algo_mode\": \"grid\", # set the search algorithm to grid search\n",
" \"space_mode\": \"grid\", # set the search space to the recommended grid space\n",
" \"transformers_verbose\": transformers.logging.ERROR\n",
" }\n",
"validation_metric, analysis = autohf.fit(**autohf_settings)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Get the time for running grid search: "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"grid search for glue_mrpc took 106.57789206504822 seconds\n"
]
}
],
"source": [
"GST = autohf.last_run_duration\n",
"print(\"grid search for {} took {} seconds\".format(autohf.jobid_config.get_jobid_full_data_name(), GST))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After the HPO run finishes, generate the predictions and save it as a .zip file to be submitted to the glue website. Here we will need the library AzureUtils which is for storing the output information (e.g., analysis log, .zip file) locally and uploading the output to an azure blob container (e.g., if multiple jobs are executed in a cluster). If the azure key and container information is not specified, the output information will only be saved locally. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"remove_columns_ is deprecated and will be removed in the next major version of datasets. Use the dataset.remove_columns method instead.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cleaning the existing label column from test data\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" </style>\n",
" \n",
" <progress value='432' max='432' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [432/432 00:34]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"JobID(dat=['glue'], subdat='mrpc', mod='grid', spa='grid', arg='dft', alg='grid', pru='None', pre_full='google/electra-base-discriminator', pre='electra', presz='base', spt='ori', rep=0, sddt=43, sdhf=42, var1=None, var2=None)\n",
"Your output will not be synced to azure because azure key and container name are not specified\n",
"The path for saving the prediction .zip file is not specified, setting to data/ by default\n",
"Your output will not be synced to azure because azure key and container name are not specified\n",
"{'eval_accuracy': 0.7254901960784313, 'eval_f1': 0.8276923076923076, 'eval_loss': 0.516851007938385}\n"
]
}
],
"source": [
"predictions, test_metric = autohf.predict()\n",
"from flaml.nlp import AzureUtils\n",
"\n",
"print(autohf.jobid_config)\n",
"\n",
"azure_utils = AzureUtils(root_log_path=\"logs_test/\", autohf=autohf)\n",
"azure_utils.write_autohf_output(valid_metric=validation_metric,\n",
" predictions=predictions,\n",
" duration=GST)\n",
"print(validation_metric)"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The validation F1/accuracy we got was 92.4/89.5. After the above steps, you will find a .zip file for the predictions under data/result/. Submit the .zip file to the glue website. The test F1/accuracy we got was 90.4/86.7. As an example, we only run the experiment one time, but in general, we should run the experiment multiple repetitions and report the averaged validation and test accuracy."
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"### Running Random Search\n",
"\n",
"Next, we run random search with the same time budget as grid search:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"def tune_hpo(time_budget, this_hpo_space):\n",
" autohf_settings = {\n",
" \"resources_per_trial\": {\"gpu\": 1, \"cpu\": 1},\n",
" \"num_samples\": -1,\n",
" \"time_budget\": time_budget,\n",
" \"ckpt_per_epoch\": 5,\n",
" \"fp16\": True,\n",
" \"algo_mode\": \"hpo\", # set the search algorithm mode to hpo\n",
" \"algo_name\": \"rs\",\n",
" \"space_mode\": \"cus\", # customized search space (this_hpo_space)\n",
" \"hpo_space\": this_hpo_space,\n",
" \"transformers_verbose\": transformers.logging.ERROR\n",
" }\n",
" validation_metric, analysis = autohf.fit(**autohf_settings)\n",
" predictions, test_metric = autohf.predict()\n",
" azure_utils = AzureUtils(root_log_path=\"logs_test/\", autohf=autohf)\n",
" azure_utils.write_autohf_output(valid_metric=validation_metric,\n",
" predictions=predictions,\n",
" duration=GST)\n",
" print(validation_metric)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"== Status ==<br>Memory usage on this node: 30.1/376.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/96 CPUs, 0/4 GPUs, 0.0/247.51 GiB heap, 0.0/75.93 GiB objects (0/1.0 accelerator_type:V100)<br>Current best trial: c67b4_00003 with accuracy=0.7303921568627451 and parameters={'learning_rate': 4.030097060410288e-05, 'warmup_ratio': 0.06084844859190755, 'num_train_epochs': 0.5, 'per_device_train_batch_size': 16, 'weight_decay': 0.15742692948967135, 'attention_probs_dropout_prob': 0.08638900372842316, 'hidden_dropout_prob': 0.058245828039608386, 'seed': 42}<br>Result logdir: /data/xliu127/projects/hyperopt/FLAML/notebook/data/checkpoint/dat=glue_subdat=mrpc_mod=hpo_spa=cus_arg=dft_alg=rs_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None/ray_result<br>Number of trials: 8/infinite (8 TERMINATED)<br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[2m\u001B[36m(pid=50964)\u001B[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
"\u001B[2m\u001B[36m(pid=50964)\u001B[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
"\u001B[2m\u001B[36m(pid=50948)\u001B[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n",
"\u001B[2m\u001B[36m(pid=50948)\u001B[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-06-16 10:48:21,624\tINFO tune.py:450 -- Total run time: 114.32 seconds (109.41 seconds for the tuning loop).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total running time: 114.35665488243103 seconds\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" </style>\n",
" \n",
" <progress value='432' max='432' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [432/432 00:33]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Your output will not be synced to azure because azure key and container name are not specified\n",
"The path for saving the prediction .zip file is not specified, setting to data/ by default\n",
"Your output will not be synced to azure because azure key and container name are not specified\n",
"{'eval_accuracy': 0.7328431372549019, 'eval_f1': 0.8320493066255777, 'eval_loss': 0.5411379933357239}\n"
]
}
],
"source": [
"hpo_space_full = {\n",
" \"learning_rate\": {\"l\": 3e-5, \"u\": 1.5e-4, \"space\": \"log\"},\n",
" \"warmup_ratio\": {\"l\": 0, \"u\": 0.2, \"space\": \"linear\"},\n",
" \"num_train_epochs\": [3],\n",
" \"per_device_train_batch_size\": [16, 32, 64],\n",
" \"weight_decay\": {\"l\": 0.0, \"u\": 0.3, \"space\": \"linear\"},\n",
" \"attention_probs_dropout_prob\": {\"l\": 0, \"u\": 0.2, \"space\": \"linear\"},\n",
" \"hidden_dropout_prob\": {\"l\": 0, \"u\": 0.2, \"space\": \"linear\"},\n",
" }\n",
"\n",
"tune_hpo(GST, hpo_space_full)"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The validation F1/accuracy we got was 93.5/90.9. Similarly, we can submit the .zip file to the glue website. The test F1/accuaracy we got was 81.6/70.2. "
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"## 3. Troubleshooting HPO Failures\n",
"\n",
"Since the validation accuracy is larger than grid search while the test accuracy is smaller, HPO has overfitting. We reduce the search space:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/html": [
"== Status ==<br>Memory usage on this node: 26.5/376.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/96 CPUs, 0/4 GPUs, 0.0/247.51 GiB heap, 0.0/75.93 GiB objects (0/1.0 accelerator_type:V100)<br>Current best trial: 234d8_00003 with accuracy=0.7475490196078431 and parameters={'learning_rate': 0.00011454435497690623, 'warmup_ratio': 0.1, 'num_train_epochs': 0.5, 'per_device_train_batch_size': 16, 'weight_decay': 0.06370173320348284, 'attention_probs_dropout_prob': 0.03636499344142013, 'hidden_dropout_prob': 0.03668090197068676, 'seed': 42}<br>Result logdir: /data/xliu127/projects/hyperopt/FLAML/notebook/data/checkpoint/dat=glue_subdat=mrpc_mod=hpo_spa=cus_arg=dft_alg=rs_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None/ray_result<br>Number of trials: 6/infinite (6 TERMINATED)<br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-06-16 10:51:34,598\tINFO tune.py:450 -- Total run time: 151.57 seconds (136.77 seconds for the tuning loop).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total running time: 151.59901237487793 seconds\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" </style>\n",
" \n",
" <progress value='432' max='432' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [432/432 00:33]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Your output will not be synced to azure because azure key and container name are not specified\n",
"The path for saving the prediction .zip file is not specified, setting to data/ by default\n",
"Your output will not be synced to azure because azure key and container name are not specified\n",
"{'eval_accuracy': 0.7475490196078431, 'eval_f1': 0.8325203252032519, 'eval_loss': 0.5056071877479553}\n"
]
}
],
"source": [
"hpo_space_fixwr = {\n",
" \"learning_rate\": {\"l\": 3e-5, \"u\": 1.5e-4, \"space\": \"log\"},\n",
" \"warmup_ratio\": [0.1],\n",
" \"num_train_epochs\": [3],\n",
" \"per_device_train_batch_size\": [16, 32, 64],\n",
" \"weight_decay\": {\"l\": 0.0, \"u\": 0.3, \"space\": \"linear\"},\n",
" \"attention_probs_dropout_prob\": {\"l\": 0, \"u\": 0.2, \"space\": \"linear\"},\n",
" \"hidden_dropout_prob\": {\"l\": 0, \"u\": 0.2, \"space\": \"linear\"},\n",
" }\n",
"tune_hpo(GST, hpo_space_fixwr)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The validation F1/accuracy we got was 92.6/89.7, the test F1/accuracy was 85.9/78.7, therefore overfitting still exists and we further reduce the space: "
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/html": [
"== Status ==<br>Memory usage on this node: 29.6/376.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/96 CPUs, 0/4 GPUs, 0.0/247.46 GiB heap, 0.0/75.93 GiB objects (0/1.0 accelerator_type:V100)<br>Current best trial: 96a67_00003 with accuracy=0.7107843137254902 and parameters={'learning_rate': 7.862589064613256e-05, 'warmup_ratio': 0.1, 'num_train_epochs': 0.5, 'per_device_train_batch_size': 32, 'weight_decay': 0.0, 'attention_probs_dropout_prob': 0.1, 'hidden_dropout_prob': 0.1, 'seed': 42}<br>Result logdir: /data/xliu127/projects/hyperopt/FLAML/notebook/data/checkpoint/dat=glue_subdat=mrpc_mod=hpo_spa=cus_arg=dft_alg=rs_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None/ray_result<br>Number of trials: 6/infinite (6 TERMINATED)<br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-06-16 10:54:14,542\tINFO tune.py:450 -- Total run time: 117.99 seconds (112.99 seconds for the tuning loop).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total running time: 118.01927375793457 seconds\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" </style>\n",
" \n",
" <progress value='432' max='432' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [432/432 00:33]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Your output will not be synced to azure because azure key and container name are not specified\n",
"The path for saving the prediction .zip file is not specified, setting to data/ by default\n",
"Your output will not be synced to azure because azure key and container name are not specified\n",
"{'eval_accuracy': 0.7181372549019608, 'eval_f1': 0.8174962292609351, 'eval_loss': 0.5494586229324341}\n"
]
}
],
"source": [
"hpo_space_min = {\n",
" \"learning_rate\": {\"l\": 3e-5, \"u\": 1.5e-4, \"space\": \"log\"},\n",
" \"warmup_ratio\": [0.1],\n",
" \"num_train_epochs\": [3],\n",
" \"per_device_train_batch_size\": [16, 32, 64],\n",
" \"weight_decay\": [0.0],\n",
" \"attention_probs_dropout_prob\": [0.1],\n",
" \"hidden_dropout_prob\": [0.1],\n",
" }\n",
"tune_hpo(GST, hpo_space_min)"
]
},
{
"cell_type": "markdown",
"metadata": {
"pycharm": {
"name": "#%% md\n"
}
},
"source": [
"The validation F1/accuracy we got was 90.4/86.7, test F1/accuracy was 83.0/73.0. Since the validation accuracy is below grid search, we increase the budget to 4 * GST:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"== Status ==<br>Memory usage on this node: 26.2/376.6 GiB<br>Using FIFO scheduling algorithm.<br>Resources requested: 0/96 CPUs, 0/4 GPUs, 0.0/247.46 GiB heap, 0.0/75.93 GiB objects (0/1.0 accelerator_type:V100)<br>Current best trial: f5d31_00005 with accuracy=0.7352941176470589 and parameters={'learning_rate': 3.856175093679045e-05, 'warmup_ratio': 0.1, 'num_train_epochs': 0.5, 'per_device_train_batch_size': 16, 'weight_decay': 0.0, 'attention_probs_dropout_prob': 0.1, 'hidden_dropout_prob': 0.1, 'seed': 42}<br>Result logdir: /data/xliu127/projects/hyperopt/FLAML/notebook/data/checkpoint/dat=glue_subdat=mrpc_mod=hpo_spa=cus_arg=dft_alg=rs_pru=None_pre=electra_presz=base_spt=ori_rep=0_sddt=43_sdhf=42_var1=None_var2=None/ray_result<br>Number of trials: 16/infinite (16 TERMINATED)<br><br>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-06-16 11:03:23,308\tINFO tune.py:450 -- Total run time: 507.09 seconds (445.79 seconds for the tuning loop).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total running time: 507.15925645828247 seconds\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" </style>\n",
" \n",
" <progress value='432' max='432' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [432/432 00:34]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Your output will not be synced to azure because azure key and container name are not specified\n",
"The path for saving the prediction .zip file is not specified, setting to data/ by default\n",
"Your output will not be synced to azure because azure key and container name are not specified\n",
"{'eval_accuracy': 0.7401960784313726, 'eval_f1': 0.8333333333333334, 'eval_loss': 0.5303606986999512}\n"
]
}
],
"source": [
"hpo_space_min = {\n",
" \"learning_rate\": {\"l\": 3e-5, \"u\": 1.5e-4, \"space\": \"log\"},\n",
" \"warmup_ratio\": [0.1],\n",
" \"num_train_epochs\": [3],\n",
" \"per_device_train_batch_size\": [32],\n",
" \"weight_decay\": [0.0],\n",
" \"attention_probs_dropout_prob\": [0.1],\n",
" \"hidden_dropout_prob\": [0.1],\n",
" }\n",
"tune_hpo(4 * GST, hpo_space_min)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The validation F1/accuracy we got was 93.5/91.1, where the accuracy outperforms grid search. The test F1/accuracy was 90.1/86.1. As a result, random search with 4*GST and the minimum space overfits. We stop the troubleshooting process because the search space cannot be further reduced."
]
}
],
"metadata": {
"interpreter": {
"hash": "bfcd9a6a9254a5e160761a1fd7a9e444f011592c6770d9f4180dde058a9df5dd"
},
"kernelspec": {
"display_name": "Python 3.7.7 64-bit ('flaml': conda)",
"name": "python3"
},
"language_info": {
"name": "python",
"version": ""
}
},
"nbformat": 4,
"nbformat_minor": 1
}

View File

@ -54,6 +54,10 @@ setuptools.setup(
"openml",
"statsmodels>=0.12.2",
"psutil==5.8.0",
"dataclasses",
"transformers",
"datasets==1.4.1",
"torch",
],
"catboost": ["catboost>=0.26"],
"blendsearch": ["optuna==2.8.0"],
@ -70,6 +74,7 @@ setuptools.setup(
"vw": [
"vowpalwabbit",
],
"nlp": ["transformers", "datasets==1.4.1", "torch"],
"ts_forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2"],
"forecast": ["prophet>=1.0.1", "statsmodels>=0.12.2"],
"benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],

8
test/load_args.py Normal file
View File

@ -0,0 +1,8 @@
def test_load_args_sub():
    from flaml.nlp.utils import HPOArgs

    HPOArgs.load_args()


if __name__ == "__main__":
    test_load_args_sub()
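For orientation, here is a minimal sketch of how `HPOArgs.load_args` appears to be exercised by these tests; the `--output_dir` flag and the assumption that `load_args` reads the command line are inferred from this commit's tests, not a confirmed API:

```python
# Hypothetical usage, assuming HPOArgs.load_args() parses sys.argv
# (inferred from test/load_args.py and the subprocess-based test below).
import sys

from flaml.nlp.utils import HPOArgs

sys.argv = ["load_args.py", "--output_dir", "data/"]  # simulate CLI flags
HPOArgs.load_args()  # presumably parses the flags into HPO arguments
```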

View File

@ -19,7 +19,7 @@ import torch.optim as optim
from nni.utils import merge_parameter
from torchvision import datasets, transforms
logger = logging.getLogger('mnist_AutoML')
logger = logging.getLogger("mnist_AutoML")
class Net(nn.Module):
@ -44,7 +44,7 @@ class Net(nn.Module):
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if (args['batch_num'] is not None) and batch_idx >= args['batch_num']:
if (args["batch_num"] is not None) and batch_idx >= args["batch_num"]:
break
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
@ -52,10 +52,16 @@ def train(args, model, device, train_loader, optimizer, epoch):
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args['log_interval'] == 0:
logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
if batch_idx % args["log_interval"] == 0:
logger.info(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch,
batch_idx * len(data),
len(train_loader.dataset),
100.0 * batch_idx / len(train_loader),
loss.item(),
)
)
def test(args, model, device, test_loader):
@ -67,95 +73,140 @@ def test(args, model, device, test_loader):
data, target = data.to(device), target.to(device)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item()
test_loss += F.nll_loss(output, target, reduction="sum").item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
accuracy = 100.0 * correct / len(test_loader.dataset)
logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset), accuracy))
logger.info(
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
test_loss, correct, len(test_loader.dataset), accuracy
)
)
return accuracy
def main(args):
use_cuda = not args['no_cuda'] and torch.cuda.is_available()
use_cuda = not args["no_cuda"] and torch.cuda.is_available()
torch.manual_seed(args['seed'])
torch.manual_seed(args["seed"])
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
data_dir = args['data_dir']
data_dir = args["data_dir"]
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args['batch_size'], shuffle=True, **kwargs)
datasets.MNIST(
data_dir,
train=True,
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=args["batch_size"],
shuffle=True,
**kwargs
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=1000, shuffle=True, **kwargs)
datasets.MNIST(
data_dir,
train=False,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
),
),
batch_size=1000,
shuffle=True,
**kwargs
)
hidden_size = args['hidden_size']
hidden_size = args["hidden_size"]
model = Net(hidden_size=hidden_size).to(device)
optimizer = optim.SGD(model.parameters(), lr=args['lr'],
momentum=args['momentum'])
optimizer = optim.SGD(model.parameters(), lr=args["lr"], momentum=args["momentum"])
for epoch in range(1, args['epochs'] + 1):
for epoch in range(1, args["epochs"] + 1):
train(args, model, device, train_loader, optimizer, epoch)
test_acc = test(args, model, device, test_loader)
# report intermediate result
nni.report_intermediate_result(test_acc)
logger.debug('test accuracy %g', test_acc)
logger.debug('Pipe send intermediate result done.')
logger.debug("test accuracy %g", test_acc)
logger.debug("Pipe send intermediate result done.")
# report final result
nni.report_final_result(test_acc)
logger.debug('Final result is %g', test_acc)
logger.debug('Send final result done.')
logger.debug("Final result is %g", test_acc)
logger.debug("Send final result done.")
def get_params():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument("--data_dir", type=str,
default='./data', help="data directory")
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument("--data_dir", type=str, default="./data", help="data directory")
parser.add_argument(
"--batch_size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)",
)
parser.add_argument("--batch_num", type=int, default=None)
parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
help='hidden layer size (default: 512)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument(
"--hidden_size",
type=int,
default=512,
metavar="N",
help="hidden layer size (default: 512)",
)
parser.add_argument(
"--lr",
type=float,
default=0.01,
metavar="LR",
help="learning rate (default: 0.01)",
)
parser.add_argument(
"--momentum",
type=float,
default=0.5,
metavar="M",
help="SGD momentum (default: 0.5)",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
metavar="N",
help="number of epochs to train (default: 10)",
)
parser.add_argument(
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
)
parser.add_argument(
"--no_cuda", action="store_true", default=False, help="disables CUDA training"
)
parser.add_argument(
"--log_interval",
type=int,
default=1000,
metavar="N",
help="how many batches to wait before logging training status",
)
args, _ = parser.parse_known_args()
return args
if __name__ == '__main__':
if __name__ == "__main__":
try:
# get parameters form tuner
tuner_params = nni.get_next_parameter()

124
test/test_autohf.py Normal file
View File

@ -0,0 +1,124 @@
def test_hf_data():
    try:
        import ray
    except ImportError:
        return
    from flaml import AutoML
    from datasets import load_dataset

    # smoke test on a tiny slice of GLUE MRPC
    train_dataset = (
        load_dataset("glue", "mrpc", split="train[:1%]").to_pandas().iloc[0:4]
    )
    dev_dataset = (
        load_dataset("glue", "mrpc", split="train[1%:2%]").to_pandas().iloc[0:4]
    )
    test_dataset = (
        load_dataset("glue", "mrpc", split="test[1%:2%]").to_pandas().iloc[0:4]
    )

    custom_sent_keys = ["sentence1", "sentence2"]
    label_key = "label"

    X_train = train_dataset[custom_sent_keys]
    y_train = train_dataset[label_key]

    X_val = dev_dataset[custom_sent_keys]
    y_val = dev_dataset[label_key]

    X_test = test_dataset[custom_sent_keys]

    automl = AutoML()

    automl_settings = {
        "gpu_per_trial": 0,
        "max_iter": 3,
        "time_budget": 20,
        "task": "seq-classification",
        "metric": "accuracy",
        "model_history": True,
    }

    automl_settings["custom_hpo_args"] = {
        "model_path": "google/electra-small-discriminator",
        "output_dir": "data/output/",
        "ckpt_per_epoch": 5,
        "fp16": False,
    }

    automl.fit(
        X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings
    )

    # retrain the best configuration from the log, then predict
    automl = AutoML()
    automl.retrain_from_log(
        log_file_name="flaml.log",
        X_train=X_train,
        y_train=y_train,
        train_full=True,
        record_id=0,
        **automl_settings
    )
    automl.predict(X_test)
    automl.predict(["test test", "test test"])
    automl.predict(
        [
            ["test test", "test test"],
            ["test test", "test test"],
            ["test test", "test test"],
        ]
    )


def _test_custom_data():
    try:
        import ray
    except ImportError:
        return
    from flaml import AutoML
    import pandas as pd

    train_dataset = pd.read_csv("data/input/train.tsv", delimiter="\t", quoting=3)
    dev_dataset = pd.read_csv("data/input/dev.tsv", delimiter="\t", quoting=3)
    test_dataset = pd.read_csv("data/input/test.tsv", delimiter="\t", quoting=3)

    custom_sent_keys = ["#1 String", "#2 String"]
    label_key = "Quality"

    X_train = train_dataset[custom_sent_keys]
    y_train = train_dataset[label_key]

    X_val = dev_dataset[custom_sent_keys]
    y_val = dev_dataset[label_key]

    X_test = test_dataset[custom_sent_keys]

    automl = AutoML()

    automl_settings = {
        "gpu_per_trial": 0,
        "max_iter": 10,
        "time_budget": 300,
        "task": "seq-classification",
        "metric": "accuracy",
    }

    automl_settings["custom_hpo_args"] = {
        "model_path": "google/electra-small-discriminator",
        "output_dir": "data/output/",
        "ckpt_per_epoch": 1,
    }

    automl.fit(
        X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings
    )

    automl.predict(X_test)
    automl.predict(["test test"])
    automl.predict(
        [
            ["test test", "test test"],
            ["test test", "test test"],
            ["test test", "test test"],
        ]
    )
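Note the three input formats `predict` is exercised with above: a pandas DataFrame holding the two sentence columns, a plain list of strings, and a list of sentence-pair lists.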

View File

@ -0,0 +1,42 @@
def test_classification_head():
    try:
        import ray
    except ImportError:
        return
    from flaml import AutoML
    from datasets import load_dataset

    train_dataset = load_dataset("emotion", split="train[:1%]").to_pandas().iloc[0:10]
    dev_dataset = load_dataset("emotion", split="train[1%:2%]").to_pandas().iloc[0:10]

    custom_sent_keys = ["text"]
    label_key = "label"

    X_train = train_dataset[custom_sent_keys]
    y_train = train_dataset[label_key]

    X_val = dev_dataset[custom_sent_keys]
    y_val = dev_dataset[label_key]

    automl = AutoML()

    automl_settings = {
        "gpu_per_trial": 0,
        "max_iter": 3,
        "time_budget": 20,
        "task": "seq-classification",
        "metric": "accuracy",
        "model_history": True,
    }

    automl_settings["custom_hpo_args"] = {
        "model_path": "google/electra-small-discriminator",
        "output_dir": "data/output/",
        "ckpt_per_epoch": 5,
        "fp16": False,
    }

    automl.fit(
        X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings
    )
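The emotion dataset is multi-class (six labels), so this test presumably checks that the classification head is sized to the number of labels rather than assuming binary classification.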

39
test/test_autohf_cv.py Normal file
View File

@ -0,0 +1,39 @@
def test_cv():
    try:
        import ray
    except ImportError:
        return
    from flaml import AutoML
    from datasets import load_dataset

    train_dataset = (
        load_dataset("glue", "mrpc", split="train[:1%]").to_pandas().iloc[0:4]
    )

    custom_sent_keys = ["sentence1", "sentence2"]
    label_key = "label"

    X_train = train_dataset[custom_sent_keys]
    y_train = train_dataset[label_key]

    automl = AutoML()

    automl_settings = {
        "gpu_per_trial": 0,
        "max_iter": 3,
        "time_budget": 20,
        "task": "seq-classification",
        "metric": "accuracy",
        "n_splits": 3,
        "model_history": True,
    }

    automl_settings["custom_hpo_args"] = {
        "model_path": "google/electra-small-discriminator",
        "output_dir": "data/output/",
        "ckpt_per_epoch": 1,
        "fp16": False,
    }

    automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
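No validation set is passed here; with "n_splits": 3 in the settings, evaluation presumably falls back to 3-fold cross-validation on the training data.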

View File

@ -0,0 +1,7 @@
def test_load_args():
    import subprocess
    import sys

    # run test/load_args.py with a sample flag; passing the argument list
    # directly (without shell=True) ensures the flags reach the script
    subprocess.call([sys.executable, "load_args.py", "--output_dir", "data/"])

View File

@ -0,0 +1,46 @@
def test_regression():
    try:
        import ray
    except ImportError:
        return
    from flaml import AutoML
    from datasets import load_dataset

    train_dataset = (
        load_dataset("glue", "stsb", split="train[:1%]").to_pandas().iloc[0:4]
    )
    dev_dataset = (
        load_dataset("glue", "stsb", split="train[1%:2%]").to_pandas().iloc[0:4]
    )

    custom_sent_keys = ["sentence1", "sentence2"]
    label_key = "label"

    X_train = train_dataset[custom_sent_keys]
    y_train = train_dataset[label_key]

    X_val = dev_dataset[custom_sent_keys]
    y_val = dev_dataset[label_key]

    automl = AutoML()

    automl_settings = {
        "gpu_per_trial": 0,
        "max_iter": 3,
        "time_budget": 20,
        "task": "seq-regression",
        "metric": "rmse",
        "model_history": True,
    }

    automl_settings["custom_hpo_args"] = {
        "model_path": "google/electra-small-discriminator",
        "output_dir": "data/output/",
        "ckpt_per_epoch": 5,
        "fp16": False,
    }

    automl.fit(
        X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, **automl_settings
    )
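The `seq-regression` task mirrors the `seq-classification` calls above, but with a continuous label (GLUE STS-B) and an `"rmse"` metric.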

View File

@ -13,21 +13,27 @@ import string
import os
import openml
VW_DS_DIR = 'test/data/'
VW_DS_DIR = "test/data/"
NS_LIST = list(string.ascii_lowercase) + list(string.ascii_uppercase)
logger = logging.getLogger(__name__)
def oml_to_vw_w_grouping(X, y, ds_dir, fname, orginal_dim, group_num,
grouping_method='sequential'):
def oml_to_vw_w_grouping(
X, y, ds_dir, fname, orginal_dim, group_num, grouping_method="sequential"
):
# split all_indexes into # group_num of groups
max_size_per_group = int(np.ceil(orginal_dim / float(group_num)))
# sequential grouping
if grouping_method == 'sequential':
if grouping_method == "sequential":
group_indexes = [] # lists of lists
for i in range(group_num):
indexes = [ind for ind in range(i * max_size_per_group,
min((i + 1) * max_size_per_group, orginal_dim))]
indexes = [
ind
for ind in range(
i * max_size_per_group,
min((i + 1) * max_size_per_group, orginal_dim),
)
]
if len(indexes) > 0:
group_indexes.append(indexes)
print(group_indexes)
@ -36,34 +42,39 @@ def oml_to_vw_w_grouping(X, y, ds_dir, fname, orginal_dim, group_num,
if group_indexes:
if not os.path.exists(ds_dir):
os.makedirs(ds_dir)
with open(os.path.join(ds_dir, fname), 'w') as f:
with open(os.path.join(ds_dir, fname), "w") as f:
if isinstance(X, pd.DataFrame):
raise NotImplementedError
elif isinstance(X, np.ndarray):
for i in range(len(X)):
NS_content = []
for zz in range(len(group_indexes)):
ns_features = ' '.join('{}:{:.6f}'.format(ind, X[i][ind]
) for ind in group_indexes[zz])
ns_features = " ".join(
"{}:{:.6f}".format(ind, X[i][ind])
for ind in group_indexes[zz]
)
NS_content.append(ns_features)
ns_line = '{} |{}'.format(str(y[i]), '|'.join(
'{} {}'.format(NS_LIST[j], NS_content[j]
) for j in range(len(group_indexes))))
ns_line = "{} |{}".format(
str(y[i]),
"|".join(
"{} {}".format(NS_LIST[j], NS_content[j])
for j in range(len(group_indexes))
),
)
f.write(ns_line)
f.write('\n')
f.write("\n")
elif isinstance(X, scipy.sparse.csr_matrix):
print('NotImplementedError for sparse data')
print("NotImplementedError for sparse data")
NotImplementedError
def save_vw_dataset_w_ns(X, y, did, ds_dir, max_ns_num, is_regression):
""" convert openml dataset to vw example and save to file
"""
print('is_regression', is_regression)
"""convert openml dataset to vw example and save to file"""
print("is_regression", is_regression)
if is_regression:
fname = 'ds_{}_{}_{}.vw'.format(did, max_ns_num, 0)
print('dataset size', X.shape[0], X.shape[1])
print('saving data', did, ds_dir, fname)
fname = "ds_{}_{}_{}.vw".format(did, max_ns_num, 0)
print("dataset size", X.shape[0], X.shape[1])
print("saving data", did, ds_dir, fname)
dim = X.shape[1]
oml_to_vw_w_grouping(X, y, ds_dir, fname, dim, group_num=max_ns_num)
else:
@ -84,62 +95,75 @@ def shuffle_data(X, y, seed):
def get_oml_to_vw(did, max_ns_num, ds_dir=VW_DS_DIR):
success = False
print('-----getting oml dataset-------', did)
print("-----getting oml dataset-------", did)
ds = openml.datasets.get_dataset(did)
target_attribute = ds.default_target_attribute
# if target_attribute is None and did in OML_target_attribute_dict:
# target_attribute = OML_target_attribute_dict[did]
print('target=ds.default_target_attribute', target_attribute)
data = ds.get_data(target=target_attribute, dataset_format='array')
print("target=ds.default_target_attribute", target_attribute)
data = ds.get_data(target=target_attribute, dataset_format="array")
X, y = data[0], data[1] # return X: pd DataFrame, y: pd series
import scipy
if scipy.sparse.issparse(X):
X = scipy.sparse.csr_matrix.toarray(X)
print('is sparse matrix')
print("is sparse matrix")
if data and isinstance(X, np.ndarray):
print('-----converting oml to vw and and saving oml dataset-------')
print("-----converting oml to vw and and saving oml dataset-------")
save_vw_dataset_w_ns(X, y, did, ds_dir, max_ns_num, is_regression=True)
success = True
else:
print('---failed to convert/save oml dataset to vw!!!----')
print("---failed to convert/save oml dataset to vw!!!----")
try:
X, y = data[0], data[1] # return X: pd DataFrame, y: pd series
if data and isinstance(X, np.ndarray):
print('-----converting oml to vw and and saving oml dataset-------')
print("-----converting oml to vw and and saving oml dataset-------")
save_vw_dataset_w_ns(X, y, did, ds_dir, max_ns_num, is_regression=True)
success = True
else:
print('---failed to convert/save oml dataset to vw!!!----')
print("---failed to convert/save oml dataset to vw!!!----")
except ValueError:
print('-------------failed to get oml dataset!!!', did)
print("-------------failed to get oml dataset!!!", did)
return success
def load_vw_dataset(did, ds_dir, is_regression, max_ns_num):
import os
if is_regression:
# the second field specifies the largest number of namespaces using.
fname = 'ds_{}_{}_{}.vw'.format(did, max_ns_num, 0)
fname = "ds_{}_{}_{}.vw".format(did, max_ns_num, 0)
vw_dataset_file = os.path.join(ds_dir, fname)
# if file does not exist, generate and save the datasets
if not os.path.exists(vw_dataset_file) or os.stat(vw_dataset_file).st_size < 1000:
if (
not os.path.exists(vw_dataset_file)
or os.stat(vw_dataset_file).st_size < 1000
):
get_oml_to_vw(did, max_ns_num)
print(ds_dir, vw_dataset_file)
if not os.path.exists(ds_dir):
os.makedirs(ds_dir)
with open(os.path.join(ds_dir, fname), 'r') as f:
with open(os.path.join(ds_dir, fname), "r") as f:
vw_content = f.read().splitlines()
print(type(vw_content), len(vw_content))
return vw_content
def get_data(iter_num=None, dataset_id=None, vw_format=True,
max_ns_num=10, shuffle=False, use_log=True, dataset_type='regression'):
logging.info('generating data')
def get_data(
iter_num=None,
dataset_id=None,
vw_format=True,
max_ns_num=10,
shuffle=False,
use_log=True,
dataset_type="regression",
):
logging.info("generating data")
LOG_TRANSFORMATION_THRESHOLD = 100
# get data from simulation
import random
vw_examples = None
data_id = int(dataset_id)
# loading oml dataset
@ -147,12 +171,13 @@ def get_data(iter_num=None, dataset_id=None, vw_format=True,
# Y = data.Y
if vw_format:
# vw_examples = data.vw_examples
vw_examples = load_vw_dataset(did=data_id, ds_dir=VW_DS_DIR, is_regression=True,
max_ns_num=max_ns_num)
vw_examples = load_vw_dataset(
did=data_id, ds_dir=VW_DS_DIR, is_regression=True, max_ns_num=max_ns_num
)
Y = []
for i, e in enumerate(vw_examples):
Y.append(float(e.split('|')[0]))
logger.debug('first data %s', vw_examples[0])
Y.append(float(e.split("|")[0]))
logger.debug("first data %s", vw_examples[0])
# do data shuffling or log transformation for oml data when needed
if shuffle:
random.seed(54321)
@ -165,55 +190,67 @@ def get_data(iter_num=None, dataset_id=None, vw_format=True,
if use_log and max((max_y - min_y), max_y) >= LOG_TRANSFORMATION_THRESHOLD:
log_vw_examples = []
for v in vw_examples:
org_y = v.split('|')[0]
y = float(v.split('|')[0])
org_y = v.split("|")[0]
y = float(v.split("|")[0])
# shift y to ensure all y are positive
if min_y <= 0:
y = y + abs(min_y) + 1
log_y = np.log(y)
log_vw = v.replace(org_y + '|', str(log_y) + ' |')
log_vw = v.replace(org_y + "|", str(log_y) + " |")
log_vw_examples.append(log_vw)
logger.info('log_vw_examples %s', log_vw_examples[0:2])
logger.info("log_vw_examples %s", log_vw_examples[0:2])
if log_vw_examples:
return log_vw_examples
return vw_examples, Y
class VowpalWabbitNamesspaceTuningProblem:
def __init__(self, max_iter_num, dataset_id, ns_num, **kwargs):
use_log = kwargs.get('use_log', True),
shuffle = kwargs.get('shuffle', False)
vw_format = kwargs.get('vw_format', True)
print('dataset_id', dataset_id)
self.vw_examples, self.Y = get_data(max_iter_num, dataset_id=dataset_id,
vw_format=vw_format, max_ns_num=ns_num,
shuffle=shuffle, use_log=use_log
)
use_log = (kwargs.get("use_log", True),)
shuffle = kwargs.get("shuffle", False)
vw_format = kwargs.get("vw_format", True)
print("dataset_id", dataset_id)
self.vw_examples, self.Y = get_data(
max_iter_num,
dataset_id=dataset_id,
vw_format=vw_format,
max_ns_num=ns_num,
shuffle=shuffle,
use_log=use_log,
)
self.max_iter_num = min(max_iter_num, len(self.Y))
self._problem_info = {'max_iter_num': self.max_iter_num,
'dataset_id': dataset_id,
'ns_num': ns_num,
}
self._problem_info = {
"max_iter_num": self.max_iter_num,
"dataset_id": dataset_id,
"ns_num": ns_num,
}
self._problem_info.update(kwargs)
self._fixed_hp_config = kwargs.get('fixed_hp_config', {})
self.namespace_feature_dim = AutoVW.get_ns_feature_dim_from_vw_example(self.vw_examples[0])
self._fixed_hp_config = kwargs.get("fixed_hp_config", {})
self.namespace_feature_dim = AutoVW.get_ns_feature_dim_from_vw_example(
self.vw_examples[0]
)
self._raw_namespaces = list(self.namespace_feature_dim.keys())
self._setup_search()
def _setup_search(self):
self._search_space = self._fixed_hp_config.copy()
self._init_config = self._fixed_hp_config.copy()
search_space = {'interactions':
polynomial_expansion_set(
init_monomials=set(self._raw_namespaces),
highest_poly_order=len(self._raw_namespaces),
allow_self_inter=False),
}
init_config = {'interactions': set()}
search_space = {
"interactions": polynomial_expansion_set(
init_monomials=set(self._raw_namespaces),
highest_poly_order=len(self._raw_namespaces),
allow_self_inter=False,
),
}
init_config = {"interactions": set()}
self._search_space.update(search_space)
self._init_config.update(init_config)
logger.info('search space %s %s %s', self._search_space, self._init_config, self._fixed_hp_config)
logger.info(
"search space %s %s %s",
self._search_space,
self._init_config,
self._fixed_hp_config,
)
@property
def init_config(self):
@ -225,7 +262,6 @@ class VowpalWabbitNamesspaceTuningProblem:
class VowpalWabbitNamesspaceLRTuningProblem(VowpalWabbitNamesspaceTuningProblem):
def __init__(self, max_iter_num, dataset_id, ns_num, **kwargs):
super().__init__(max_iter_num, dataset_id, ns_num, **kwargs)
self._setup_search()
@ -233,29 +269,34 @@ class VowpalWabbitNamesspaceLRTuningProblem(VowpalWabbitNamesspaceTuningProblem)
def _setup_search(self):
self._search_space = self._fixed_hp_config.copy()
self._init_config = self._fixed_hp_config.copy()
search_space = {'interactions':
polynomial_expansion_set(
init_monomials=set(self._raw_namespaces),
highest_poly_order=len(self._raw_namespaces),
allow_self_inter=False),
'learning_rate': loguniform(lower=2e-10, upper=1.0)
}
init_config = {'interactions': set(), 'learning_rate': 0.5}
search_space = {
"interactions": polynomial_expansion_set(
init_monomials=set(self._raw_namespaces),
highest_poly_order=len(self._raw_namespaces),
allow_self_inter=False,
),
"learning_rate": loguniform(lower=2e-10, upper=1.0),
}
init_config = {"interactions": set(), "learning_rate": 0.5}
self._search_space.update(search_space)
self._init_config.update(init_config)
logger.info('search space %s %s %s', self._search_space, self._init_config, self._fixed_hp_config)
logger.info(
"search space %s %s %s",
self._search_space,
self._init_config,
self._fixed_hp_config,
)
def get_y_from_vw_example(vw_example):
""" get y from a vw_example. this works for regression dataset
"""
return float(vw_example.split('|')[0])
"""get y from a vw_example. this works for regression dataset"""
return float(vw_example.split("|")[0])
def get_loss(y_pred, y_true, loss_func='squared'):
if 'squared' in loss_func:
def get_loss(y_pred, y_true, loss_func="squared"):
if "squared" in loss_func:
loss = mean_squared_error([y_pred], [y_true])
elif 'absolute' in loss_func:
elif "absolute" in loss_func:
loss = mean_absolute_error([y_pred], [y_true])
else:
loss = None
@ -263,7 +304,7 @@ def get_loss(y_pred, y_true, loss_func='squared'):
return loss
def online_learning_loop(iter_num, vw_examples, vw_alg, loss_func, method_name=''):
def online_learning_loop(iter_num, vw_examples, vw_alg, loss_func, method_name=""):
"""Implements the online learning loop.
Args:
iter_num (int): The total number of iterations
@ -276,7 +317,7 @@ def online_learning_loop(iter_num, vw_examples, vw_alg, loss_func, method_name='
cumulative_loss_list (list): the list of cumulative loss from each iteration.
It is returned for the convenience of visualization.
"""
print('rerunning exp....', len(vw_examples), iter_num)
print("rerunning exp....", len(vw_examples), iter_num)
loss_list = []
y_predict_list = []
for i in range(iter_num):
@ -294,23 +335,29 @@ def online_learning_loop(iter_num, vw_examples, vw_alg, loss_func, method_name='
return loss_list
def get_vw_tuning_problem(tuning_hp='NamesapceInteraction'):
online_vw_exp_setting = {"max_live_model_num": 5,
"fixed_hp_config": {'alg': 'supervised', 'loss_function': 'squared'},
"ns_num": 10,
"max_iter_num": 10000,
}
def get_vw_tuning_problem(tuning_hp="NamesapceInteraction"):
online_vw_exp_setting = {
"max_live_model_num": 5,
"fixed_hp_config": {"alg": "supervised", "loss_function": "squared"},
"ns_num": 10,
"max_iter_num": 10000,
}
# construct openml problem setting based on basic experiment setting
vw_oml_problem_args = {"max_iter_num": online_vw_exp_setting['max_iter_num'],
"dataset_id": '42183',
"ns_num": online_vw_exp_setting['ns_num'],
"fixed_hp_config": online_vw_exp_setting['fixed_hp_config'],
}
if tuning_hp == 'NamesapceInteraction':
vw_online_aml_problem = VowpalWabbitNamesspaceTuningProblem(**vw_oml_problem_args)
elif tuning_hp == 'NamesapceInteraction+LearningRate':
vw_online_aml_problem = VowpalWabbitNamesspaceLRTuningProblem(**vw_oml_problem_args)
vw_oml_problem_args = {
"max_iter_num": online_vw_exp_setting["max_iter_num"],
"dataset_id": "42183",
"ns_num": online_vw_exp_setting["ns_num"],
"fixed_hp_config": online_vw_exp_setting["fixed_hp_config"],
}
if tuning_hp == "NamesapceInteraction":
vw_online_aml_problem = VowpalWabbitNamesspaceTuningProblem(
**vw_oml_problem_args
)
elif tuning_hp == "NamesapceInteraction+LearningRate":
vw_online_aml_problem = VowpalWabbitNamesspaceLRTuningProblem(
**vw_oml_problem_args
)
else:
NotImplementedError
@ -318,48 +365,68 @@ def get_vw_tuning_problem(tuning_hp='NamesapceInteraction'):
class TestAutoVW(unittest.TestCase):
def test_vw_oml_problem_and_vanilla_vw(self):
vw_oml_problem_args, vw_online_aml_problem = get_vw_tuning_problem()
vanilla_vw = pyvw.vw(**vw_oml_problem_args["fixed_hp_config"])
cumulative_loss_list = online_learning_loop(vw_online_aml_problem.max_iter_num,
vw_online_aml_problem.vw_examples,
vanilla_vw,
loss_func=vw_oml_problem_args["fixed_hp_config"].get("loss_function", "squared"),
)
print('final average loss:', sum(cumulative_loss_list) / len(cumulative_loss_list))
cumulative_loss_list = online_learning_loop(
vw_online_aml_problem.max_iter_num,
vw_online_aml_problem.vw_examples,
vanilla_vw,
loss_func=vw_oml_problem_args["fixed_hp_config"].get(
"loss_function", "squared"
),
)
print(
"final average loss:", sum(cumulative_loss_list) / len(cumulative_loss_list)
)
def test_supervised_vw_tune_namespace(self):
# basic experiment setting
vw_oml_problem_args, vw_online_aml_problem = get_vw_tuning_problem()
autovw = AutoVW(max_live_model_num=5,
search_space=vw_online_aml_problem.search_space,
init_config=vw_online_aml_problem.init_config,
min_resource_lease='auto',
random_seed=2345)
autovw = AutoVW(
max_live_model_num=5,
search_space=vw_online_aml_problem.search_space,
init_config=vw_online_aml_problem.init_config,
min_resource_lease="auto",
random_seed=2345,
)
cumulative_loss_list = online_learning_loop(vw_online_aml_problem.max_iter_num,
vw_online_aml_problem.vw_examples,
autovw,
loss_func=vw_oml_problem_args["fixed_hp_config"].get("loss_function", "squared"),
)
print('final average loss:', sum(cumulative_loss_list) / len(cumulative_loss_list))
cumulative_loss_list = online_learning_loop(
vw_online_aml_problem.max_iter_num,
vw_online_aml_problem.vw_examples,
autovw,
loss_func=vw_oml_problem_args["fixed_hp_config"].get(
"loss_function", "squared"
),
)
print(
"final average loss:", sum(cumulative_loss_list) / len(cumulative_loss_list)
)
def test_supervised_vw_tune_namespace_learningrate(self):
# basic experiment setting
vw_oml_problem_args, vw_online_aml_problem = get_vw_tuning_problem(tuning_hp='NamesapceInteraction+LearningRate')
autovw = AutoVW(max_live_model_num=5,
search_space=vw_online_aml_problem.search_space,
init_config=vw_online_aml_problem.init_config,
min_resource_lease='auto',
random_seed=2345)
vw_oml_problem_args, vw_online_aml_problem = get_vw_tuning_problem(
tuning_hp="NamesapceInteraction+LearningRate"
)
autovw = AutoVW(
max_live_model_num=5,
search_space=vw_online_aml_problem.search_space,
init_config=vw_online_aml_problem.init_config,
min_resource_lease="auto",
random_seed=2345,
)
cumulative_loss_list = online_learning_loop(vw_online_aml_problem.max_iter_num,
vw_online_aml_problem.vw_examples,
autovw,
loss_func=vw_oml_problem_args["fixed_hp_config"].get("loss_function", "squared"),
)
print('final average loss:', sum(cumulative_loss_list) / len(cumulative_loss_list))
cumulative_loss_list = online_learning_loop(
vw_online_aml_problem.max_iter_num,
vw_online_aml_problem.vw_examples,
autovw,
loss_func=vw_oml_problem_args["fixed_hp_config"].get(
"loss_function", "squared"
),
)
print(
"final average loss:", sum(cumulative_loss_list) / len(cumulative_loss_list)
)
def test_bandit_vw_tune_namespace(self):
pass

View File

@ -129,7 +129,9 @@ def load_multi_dataset():
import pandas as pd
# pd.set_option("display.max_rows", None, "display.max_columns", None)
df = pd.read_csv("https://raw.githubusercontent.com/srivatsan88/YouTubeLI/master/dataset/nyc_energy_consumption.csv")
df = pd.read_csv(
"https://raw.githubusercontent.com/srivatsan88/YouTubeLI/master/dataset/nyc_energy_consumption.csv"
)
# preprocessing data
df["timeStamp"] = pd.to_datetime(df["timeStamp"])
df = df.set_index("timeStamp")
@ -150,7 +152,9 @@ def test_multivariate_forecast_num(budget=5):
split_idx = num_samples - time_horizon
train_df = df[:split_idx]
test_df = df[split_idx:]
X_test = test_df[["timeStamp", "temp", "precip"]] # test dataframe must contain values for the regressors / multivariate variables
X_test = test_df[
["timeStamp", "temp", "precip"]
] # test dataframe must contain values for the regressors / multivariate variables
y_test = test_df["demand"]
# return
automl = AutoML()
@ -161,9 +165,9 @@ def test_multivariate_forecast_num(budget=5):
"log_file_name": "test/energy_forecast_numerical.log", # flaml log file
"eval_method": "holdout",
"log_type": "all",
"label": "demand"
"label": "demand",
}
'''The main flaml automl API'''
"""The main flaml automl API"""
try:
import prophet
@ -197,8 +201,13 @@ def test_multivariate_forecast_num(budget=5):
print("mape", "=", sklearn_metric_loss_score("mape", y_pred, y_test))
from flaml.data import get_output_from_log
time_history, best_valid_loss_history, valid_loss_history, config_history, metric_history = \
get_output_from_log(filename=settings["log_file_name"], time_budget=budget)
(
time_history,
best_valid_loss_history,
valid_loss_history,
config_history,
metric_history,
) = get_output_from_log(filename=settings["log_file_name"], time_budget=budget)
for config in config_history:
print(config)
print(automl.prune_attr)
@ -253,7 +262,9 @@ def load_multi_dataset_cat(time_horizon):
return 0
df["season"] = df["timeStamp"].apply(season)
df["above_monthly_avg"] = df.apply(lambda x: above_monthly_avg(x["timeStamp"], x["temp"]), axis=1)
df["above_monthly_avg"] = df.apply(
lambda x: above_monthly_avg(x["timeStamp"], x["temp"]), axis=1
)
# split data into train and test
num_samples = df.shape[0]
@ -270,7 +281,9 @@ def test_multivariate_forecast_cat(budget=5):
time_horizon = 180
train_df, test_df = load_multi_dataset_cat(time_horizon)
print(train_df)
X_test = test_df[["timeStamp", "season", "above_monthly_avg"]] # test dataframe must contain values for the regressors / multivariate variables
X_test = test_df[
["timeStamp", "season", "above_monthly_avg"]
] # test dataframe must contain values for the regressors / multivariate variables
y_test = test_df["demand"]
automl = AutoML()
settings = {
@ -280,9 +293,9 @@ def test_multivariate_forecast_cat(budget=5):
"log_file_name": "test/energy_forecast_numerical.log", # flaml log file
"eval_method": "holdout",
"log_type": "all",
"label": "demand"
"label": "demand",
}
'''The main flaml automl API'''
"""The main flaml automl API"""
try:
import prophet
@ -319,8 +332,13 @@ def test_multivariate_forecast_cat(budget=5):
print("mae", "=", sklearn_metric_loss_score("mae", y_pred, y_test))
from flaml.data import get_output_from_log
time_history, best_valid_loss_history, valid_loss_history, config_history, metric_history = \
get_output_from_log(filename=settings["log_file_name"], time_budget=budget)
(
time_history,
best_valid_loss_history,
valid_loss_history,
config_history,
metric_history,
) = get_output_from_log(filename=settings["log_file_name"], time_budget=budget)
for config in config_history:
print(config)
print(automl.prune_attr)

View File

@ -124,4 +124,5 @@ def test_rank():
if __name__ == "__main__":
unittest.main()
# unittest.main()
test_groups()

View File

@ -3,9 +3,8 @@ import flaml
class TestVersion(unittest.TestCase):
def test_version(self):
self.assertTrue(hasattr(flaml, '__version__'))
self.assertTrue(hasattr(flaml, "__version__"))
self.assertTrue(len(flaml.__version__) > 0)

View File

@ -12,38 +12,37 @@ dataset = "credit-g"
class XGBoost2D(XGBoostSklearnEstimator):
@classmethod
def search_space(cls, data_size, task):
upper = min(32768, int(data_size))
return {
'n_estimators': {
'domain': tune.lograndint(lower=4, upper=upper),
'init_value': 4,
"n_estimators": {
"domain": tune.lograndint(lower=4, upper=upper),
"init_value": 4,
},
'max_leaves': {
'domain': tune.lograndint(lower=4, upper=upper),
'init_value': 4,
"max_leaves": {
"domain": tune.lograndint(lower=4, upper=upper),
"init_value": 4,
},
}
def _test_simple(method=None, size_ratio=1.0):
automl = AutoML()
automl.add_learner(learner_name='XGBoost2D',
learner_class=XGBoost2D)
automl.add_learner(learner_name="XGBoost2D", learner_class=XGBoost2D)
X, y = fetch_openml(name=dataset, return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33,
random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.33, random_state=42
)
final_size = int(len(y_train) * size_ratio)
X_train = X_train[:final_size]
y_train = y_train[:final_size]
automl_settings = {
"estimator_list": ['XGBoost2D'],
"estimator_list": ["XGBoost2D"],
# "metric": 'accuracy',
"task": 'classification',
"task": "classification",
"log_file_name": f"test/xgboost2d_{dataset}_{method}_{final_size}.log",
# "model_history": True,
# "log_training_metric": True,

View File

@ -2,11 +2,12 @@ import time
def evaluation_fn(step, width, height):
return (0.1 + width * step / 100)**(-1) + height * 0.1
return (0.1 + width * step / 100) ** (-1) + height * 0.1
def easy_objective(config):
from ray import tune
# Hyperparameters
width, height = config["width"], config["height"]
@ -25,7 +26,7 @@ def test_blendsearch_tune(smoke_test=True):
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.suggest.flaml import BlendSearch
except ImportError:
print('ray[tune] is not installed, skipping test')
print("ray[tune] is not installed, skipping test")
return
import numpy as np
@ -46,7 +47,8 @@ def test_blendsearch_tune(smoke_test=True):
# This is an ignored parameter.
"activation": tune.choice(["relu", "tanh"]),
"test4": np.zeros((3, 1)),
})
},
)
print("Best hyperparameters found were: ", analysis.best_config)

View File

@ -1,13 +1,14 @@
'''Require: pip install torchvision ray flaml[blendsearch]
'''
"""Require: pip install torchvision ray flaml[blendsearch]
"""
import os
import time
import numpy as np
import logging
logger = logging.getLogger(__name__)
os.makedirs('logs', exist_ok=True)
logger.addHandler(logging.FileHandler('logs/tune_pytorch_cifar10.log'))
os.makedirs("logs", exist_ok=True)
logger.addHandler(logging.FileHandler("logs/tune_pytorch_cifar10.log"))
logger.setLevel(logging.INFO)
@ -22,7 +23,6 @@ try:
# __net_begin__
class Net(nn.Module):
def __init__(self, l1=120, l2=84):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
@ -40,6 +40,7 @@ try:
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# __net_end__
except ImportError:
print("skip test_pytorch because torchvision cannot be imported.")
@ -47,18 +48,21 @@ except ImportError:
# __load_data_begin__
def load_data(data_dir="test/data"):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
trainset = torchvision.datasets.CIFAR10(
root=data_dir, train=True, download=True, transform=transform)
root=data_dir, train=True, download=True, transform=transform
)
testset = torchvision.datasets.CIFAR10(
root=data_dir, train=False, download=True, transform=transform)
root=data_dir, train=False, download=True, transform=transform
)
return trainset, testset
# __load_data_end__
@ -90,22 +94,27 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
test_abs = int(len(trainset) * 0.8)
train_subset, val_subset = random_split(
trainset, [test_abs, len(trainset) - test_abs])
trainset, [test_abs, len(trainset) - test_abs]
)
trainloader = torch.utils.data.DataLoader(
train_subset,
batch_size=int(2**config["batch_size"]),
batch_size=int(2 ** config["batch_size"]),
shuffle=True,
num_workers=4)
num_workers=4,
)
valloader = torch.utils.data.DataLoader(
val_subset,
batch_size=int(2**config["batch_size"]),
batch_size=int(2 ** config["batch_size"]),
shuffle=True,
num_workers=4)
num_workers=4,
)
from ray import tune
for epoch in range(int(round(config["num_epochs"]))): # loop over the dataset multiple times
for epoch in range(
int(round(config["num_epochs"]))
): # loop over the dataset multiple times
running_loss = 0.0
epoch_steps = 0
for i, data in enumerate(trainloader, 0):
@ -126,8 +135,10 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
running_loss += loss.item()
epoch_steps += 1
if i % 2000 == 1999: # print every 2000 mini-batches
print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
running_loss / epoch_steps))
print(
"[%d, %5d] loss: %.3f"
% (epoch + 1, i + 1, running_loss / epoch_steps)
)
running_loss = 0.0
# Validation loss
@ -154,11 +165,12 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
# parameter in future iterations.
with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
path = os.path.join(checkpoint_dir, "checkpoint")
torch.save(
(net.state_dict(), optimizer.state_dict()), path)
torch.save((net.state_dict(), optimizer.state_dict()), path)
tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
print("Finished Training")
# __train_end__
@ -167,7 +179,8 @@ def _test_accuracy(net, device="cpu"):
trainset, testset = load_data()
testloader = torch.utils.data.DataLoader(
testset, batch_size=4, shuffle=False, num_workers=2)
testset, batch_size=4, shuffle=False, num_workers=2
)
correct = 0
total = 0
@ -181,26 +194,28 @@ def _test_accuracy(net, device="cpu"):
correct += (predicted == labels).sum().item()
return correct / total
# __test_acc_end__
# __main_begin__
def cifar10_main(
method='BlendSearch', num_samples=10, max_num_epochs=100, gpus_per_trial=1
method="BlendSearch", num_samples=10, max_num_epochs=100, gpus_per_trial=1
):
data_dir = os.path.abspath("test/data")
load_data(data_dir) # Download data for all trials before starting the run
if method == 'BlendSearch':
if method == "BlendSearch":
from flaml import tune
else:
from ray import tune
if method in ['BOHB']:
if method in ["BOHB"]:
config = {
"l1": tune.randint(2, 8),
"l2": tune.randint(2, 8),
"lr": tune.loguniform(1e-4, 1e-1),
"num_epochs": tune.qloguniform(1, max_num_epochs, q=1),
"batch_size": tune.randint(1, 4)
"batch_size": tune.randint(1, 4),
}
else:
config = {
@@ -208,13 +223,14 @@ def cifar10_main(
"l2": tune.randint(2, 9),
"lr": tune.loguniform(1e-4, 1e-1),
"num_epochs": tune.loguniform(1, max_num_epochs),
"batch_size": tune.randint(1, 5)
"batch_size": tune.randint(1, 5),
}
import ray
time_budget_s = 600
np.random.seed(7654321)
start_time = time.time()
if method == 'BlendSearch':
if method == "BlendSearch":
result = tune.run(
ray.tune.with_parameters(train_cifar, data_dir=data_dir),
config=config,
@@ -225,43 +241,51 @@ def cifar10_main(
min_resource=1,
report_intermediate_result=True,
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
local_dir='logs/',
local_dir="logs/",
num_samples=num_samples,
time_budget_s=time_budget_s,
use_ray=True)
use_ray=True,
)
else:
if 'ASHA' == method:
if "ASHA" == method:
algo = None
elif 'BOHB' == method:
elif "BOHB" == method:
from ray.tune.schedulers import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
algo = TuneBOHB()
scheduler = HyperBandForBOHB(max_t=max_num_epochs)
elif 'Optuna' == method:
elif "Optuna" == method:
from ray.tune.suggest.optuna import OptunaSearch
algo = OptunaSearch(seed=10)
elif 'CFO' == method:
elif "CFO" == method:
from flaml import CFO
algo = CFO(low_cost_partial_config={
"num_epochs": 1,
})
elif 'Nevergrad' == method:
algo = CFO(
low_cost_partial_config={
"num_epochs": 1,
}
)
elif "Nevergrad" == method:
from ray.tune.suggest.nevergrad import NevergradSearch
import nevergrad as ng
algo = NevergradSearch(optimizer=ng.optimizers.OnePlusOne)
if method != 'BOHB':
if method != "BOHB":
from ray.tune.schedulers import ASHAScheduler
scheduler = ASHAScheduler(
max_t=max_num_epochs,
grace_period=1)
scheduler = ASHAScheduler(max_t=max_num_epochs, grace_period=1)
result = tune.run(
tune.with_parameters(train_cifar, data_dir=data_dir),
resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
config=config,
metric="loss",
mode="min",
num_samples=num_samples, time_budget_s=time_budget_s,
scheduler=scheduler, search_alg=algo
num_samples=num_samples,
time_budget_s=time_budget_s,
scheduler=scheduler,
search_alg=algo,
)
ray.shutdown()
@@ -270,13 +294,18 @@ def cifar10_main(
logger.info(f"time={time.time()-start_time}")
best_trial = result.get_best_trial("loss", "min", "all")
logger.info("Best trial config: {}".format(best_trial.config))
logger.info("Best trial final validation loss: {}".format(
best_trial.metric_analysis["loss"]["min"]))
logger.info("Best trial final validation accuracy: {}".format(
best_trial.metric_analysis["accuracy"]["max"]))
logger.info(
"Best trial final validation loss: {}".format(
best_trial.metric_analysis["loss"]["min"]
)
)
logger.info(
"Best trial final validation accuracy: {}".format(
best_trial.metric_analysis["accuracy"]["max"]
)
)
best_trained_model = Net(2**best_trial.config["l1"],
2**best_trial.config["l2"])
best_trained_model = Net(2 ** best_trial.config["l1"], 2 ** best_trial.config["l2"])
device = "cpu"
if torch.cuda.is_available():
device = "cuda:0"
@@ -291,6 +320,8 @@ def cifar10_main(
test_acc = _test_accuracy(best_trained_model, device)
logger.info("Best trial test set accuracy: {}".format(test_acc))
# __main_end__
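A hypothetical smoke-test invocation of the entry point above (argument values are illustrative; a real run downloads CIFAR-10 and trains):

```python
cifar10_main(method="BlendSearch", num_samples=2, max_num_epochs=2, gpus_per_trial=0)
```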
@@ -303,28 +334,23 @@ def _test_cifar10_bs():
def _test_cifar10_cfo():
cifar10_main('CFO',
num_samples=num_samples, gpus_per_trial=gpus_per_trial)
cifar10_main("CFO", num_samples=num_samples, gpus_per_trial=gpus_per_trial)
def _test_cifar10_optuna():
cifar10_main('Optuna',
num_samples=num_samples, gpus_per_trial=gpus_per_trial)
cifar10_main("Optuna", num_samples=num_samples, gpus_per_trial=gpus_per_trial)
def _test_cifar10_asha():
cifar10_main('ASHA',
num_samples=num_samples, gpus_per_trial=gpus_per_trial)
cifar10_main("ASHA", num_samples=num_samples, gpus_per_trial=gpus_per_trial)
def _test_cifar10_bohb():
cifar10_main('BOHB',
num_samples=num_samples, gpus_per_trial=gpus_per_trial)
cifar10_main("BOHB", num_samples=num_samples, gpus_per_trial=gpus_per_trial)
def _test_cifar10_nevergrad():
cifar10_main('Nevergrad',
num_samples=num_samples, gpus_per_trial=gpus_per_trial)
cifar10_main("Nevergrad", num_samples=num_samples, gpus_per_trial=gpus_per_trial)
if __name__ == "__main__":

View File

@@ -1,7 +1,19 @@
from flaml.tune.sample import (
BaseSampler, PolynomialExpansionSet, Domain,
uniform, quniform, choice, randint, qrandint, randn,
qrandn, loguniform, qloguniform, lograndint, qlograndint)
BaseSampler,
PolynomialExpansionSet,
Domain,
uniform,
quniform,
choice,
randint,
qrandint,
randn,
qrandn,
loguniform,
qloguniform,
lograndint,
qlograndint,
)
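Assumed usage of the sampling primitives imported above: each factory returns a `Domain`, and drawing a value goes through its sampler (`.sample()` here is an assumption based on the Ray Tune lineage of this module):

```python
lr = loguniform(1e-4, 1e-1)  # log-scale float domain
bs = qrandint(8, 128, q=8)   # integers quantized to multiples of 8
print(lr.sample(), bs.sample())  # assumed Domain.sample() API
```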
def test_sampler():

View File

@@ -1,11 +1,12 @@
'''Require: pip install flaml[test,ray]
'''
"""Require: pip install flaml[test,ray]
"""
from flaml.searcher.blendsearch import BlendSearch
import time
import os
from sklearn.model_selection import train_test_split
import sklearn.metrics
import sklearn.datasets
try:
from ray.tune.integration.xgboost import TuneReportCheckpointCallback
except ImportError:
@@ -13,9 +14,10 @@ except ImportError:
import xgboost as xgb
import logging
logger = logging.getLogger(__name__)
os.makedirs('logs', exist_ok=True)
logger.addHandler(logging.FileHandler('logs/tune.log'))
os.makedirs("logs", exist_ok=True)
logger.addHandler(logging.FileHandler("logs/tune.log"))
logger.setLevel(logging.INFO)
@@ -24,8 +26,7 @@ def train_breast_cancer(config: dict):
# Load dataset
data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
# Split into train and test set
train_x, test_x, train_y, test_y = train_test_split(
data, labels, test_size=0.25)
train_x, test_x, train_y, test_y = train_test_split(data, labels, test_size=0.25)
# Build input matrices for XGBoost
train_set = xgb.DMatrix(train_x, label=train_y)
test_set = xgb.DMatrix(test_x, label=test_y)
@@ -39,24 +40,26 @@ def train_breast_cancer(config: dict):
train_set,
evals=[(test_set, "eval")],
verbose_eval=False,
callbacks=[TuneReportCheckpointCallback(filename="model.xgb")])
callbacks=[TuneReportCheckpointCallback(filename="model.xgb")],
)
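For readers unfamiliar with the XGBoost pieces above, a self-contained sketch using the standard xgboost API (synthetic data, hypothetical parameters):

```python
import numpy as np
import xgboost as xgb

X = np.random.rand(100, 4)
y = np.random.randint(0, 2, 100)
dtrain = xgb.DMatrix(X, label=y)  # same input-matrix type as train_set above
booster = xgb.train(
    {"objective": "binary:logistic", "eval_metric": "logloss"},
    dtrain,
    num_boost_round=5,
)
```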
def _test_xgboost(method='BlendSearch'):
def _test_xgboost(method="BlendSearch"):
try:
import ray
except ImportError:
return
if method == 'BlendSearch':
if method == "BlendSearch":
from flaml import tune
else:
from ray import tune
search_space = {
"max_depth": tune.randint(1, 9) if method in [
"BlendSearch", "BOHB", "Optuna"] else tune.randint(1, 9),
"max_depth": tune.randint(1, 9)
if method in ["BlendSearch", "BOHB", "Optuna"]
else tune.randint(1, 9),
"min_child_weight": tune.choice([1, 2, 3]),
"subsample": tune.uniform(0.5, 1.0),
"eta": tune.loguniform(1e-4, 1e-1)
"eta": tune.loguniform(1e-4, 1e-1),
}
max_iter = 10
for num_samples in [128]:
@@ -66,7 +69,7 @@ def _test_xgboost(method='BlendSearch'):
ray.shutdown()
ray.init(num_cpus=n_cpu, num_gpus=0)
# ray.init(address='auto')
if method == 'BlendSearch':
if method == "BlendSearch":
analysis = tune.run(
train_breast_cancer,
config=search_space,
@@ -83,70 +86,89 @@ def _test_xgboost(method='BlendSearch'):
report_intermediate_result=True,
# You can add "gpu": 0.1 to allocate GPUs
resources_per_trial={"cpu": 1},
local_dir='logs/',
local_dir="logs/",
num_samples=num_samples * n_cpu,
time_budget_s=time_budget_s,
use_ray=True)
use_ray=True,
)
else:
if 'ASHA' == method:
if "ASHA" == method:
algo = None
elif 'BOHB' == method:
elif "BOHB" == method:
from ray.tune.schedulers import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
algo = TuneBOHB(max_concurrent=n_cpu)
scheduler = HyperBandForBOHB(max_t=max_iter)
elif 'Optuna' == method:
elif "Optuna" == method:
from ray.tune.suggest.optuna import OptunaSearch
algo = OptunaSearch()
elif 'CFO' == method:
elif "CFO" == method:
from flaml import CFO
algo = CFO(low_cost_partial_config={
"max_depth": 1,
}, cat_hp_cost={
"min_child_weight": [6, 3, 2],
})
elif 'CFOCat' == method:
algo = CFO(
low_cost_partial_config={
"max_depth": 1,
},
cat_hp_cost={
"min_child_weight": [6, 3, 2],
},
)
elif "CFOCat" == method:
from flaml.searcher.cfo_cat import CFOCat
algo = CFOCat(low_cost_partial_config={
"max_depth": 1,
}, cat_hp_cost={
"min_child_weight": [6, 3, 2],
})
elif 'Dragonfly' == method:
algo = CFOCat(
low_cost_partial_config={
"max_depth": 1,
},
cat_hp_cost={
"min_child_weight": [6, 3, 2],
},
)
elif "Dragonfly" == method:
from ray.tune.suggest.dragonfly import DragonflySearch
algo = DragonflySearch()
elif 'SkOpt' == method:
elif "SkOpt" == method:
from ray.tune.suggest.skopt import SkOptSearch
algo = SkOptSearch()
elif 'Nevergrad' == method:
elif "Nevergrad" == method:
from ray.tune.suggest.nevergrad import NevergradSearch
import nevergrad as ng
algo = NevergradSearch(optimizer=ng.optimizers.OnePlusOne)
elif 'ZOOpt' == method:
elif "ZOOpt" == method:
from ray.tune.suggest.zoopt import ZOOptSearch
algo = ZOOptSearch(budget=num_samples * n_cpu)
elif 'Ax' == method:
elif "Ax" == method:
from ray.tune.suggest.ax import AxSearch
algo = AxSearch()
elif 'HyperOpt' == method:
elif "HyperOpt" == method:
from ray.tune.suggest.hyperopt import HyperOptSearch
algo = HyperOptSearch()
scheduler = None
if method != 'BOHB':
if method != "BOHB":
from ray.tune.schedulers import ASHAScheduler
scheduler = ASHAScheduler(
max_t=max_iter,
grace_period=1)
scheduler = ASHAScheduler(max_t=max_iter, grace_period=1)
analysis = tune.run(
train_breast_cancer,
metric="eval-logloss",
mode="min",
# You can add "gpu": 0.1 to allocate GPUs
resources_per_trial={"cpu": 1},
config=search_space, local_dir='logs/',
config=search_space,
local_dir="logs/",
num_samples=num_samples * n_cpu,
time_budget_s=time_budget_s,
scheduler=scheduler, search_alg=algo)
scheduler=scheduler,
search_alg=algo,
)
ray.shutdown()
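The `scheduler` + `search_alg` wiring above, in a minimal self-contained analogue (Ray Tune API of this era; the objective is a stand-in):

```python
from ray import tune
from ray.tune.schedulers import ASHAScheduler

def f(config):
    tune.report(score=(config["x"] - 0.5) ** 2)

tune.run(
    f,
    config={"x": tune.uniform(0, 1)},
    metric="score",
    mode="min",
    num_samples=4,
    scheduler=ASHAScheduler(max_t=10, grace_period=1),  # search_alg=None -> random search
)
```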
# # Load the best model checkpoint
# import os
@@ -154,7 +176,7 @@ def _test_xgboost(method='BlendSearch'):
# best_bst.load_model(os.path.join(analysis.best_checkpoint,
# "model.xgb"))
best_trial = analysis.get_best_trial("eval-logloss", "min", "all")
accuracy = 1. - best_trial.metric_analysis["eval-error"]["min"]
accuracy = 1.0 - best_trial.metric_analysis["eval-error"]["min"]
logloss = best_trial.metric_analysis["eval-logloss"]["min"]
logger.info(f"method={method}")
logger.info(f"n_samples={num_samples*n_cpu}")
@@ -166,6 +188,7 @@ def _test_xgboost(method='BlendSearch'):
def test_nested():
from flaml import tune, CFO
search_space = {
# test nested search space
"cost_related": {
@@ -175,27 +198,30 @@ def test_nested():
}
def simple_func(config):
obj = (config["cost_related"]["a"] - 4)**2 \
+ (config["b"] - config["cost_related"]["a"])**2
obj = (config["cost_related"]["a"] - 4) ** 2 + (
config["b"] - config["cost_related"]["a"]
) ** 2
tune.report(obj=obj)
tune.report(obj=obj, ab=config["cost_related"]["a"] * config["b"])
analysis = tune.run(
simple_func,
search_alg=CFO(
space=search_space, metric="obj", mode="min",
low_cost_partial_config={
"cost_related": {"a": 1}
},
space=search_space,
metric="obj",
mode="min",
low_cost_partial_config={"cost_related": {"a": 1}},
points_to_evaluate=[
{"b": .99, "cost_related": {"a": 3}},
{"b": .99, "cost_related": {"a": 2}},
{"cost_related": {"a": 8}}
{"b": 0.99, "cost_related": {"a": 3}},
{"b": 0.99, "cost_related": {"a": 2}},
{"cost_related": {"a": 8}},
],
metric_constraints=[("ab", "<=", 4)]),
local_dir='logs/',
metric_constraints=[("ab", "<=", 4)],
),
local_dir="logs/",
num_samples=-1,
time_budget_s=1)
time_budget_s=1,
)
best_trial = analysis.get_best_trial()
logger.info(f"CFO best config: {best_trial.config}")
@@ -205,46 +231,47 @@ def test_nested():
simple_func,
search_alg=BlendSearch(
experimental=True,
space=search_space, metric="obj", mode="min",
low_cost_partial_config={
"cost_related": {"a": 1}
},
space=search_space,
metric="obj",
mode="min",
low_cost_partial_config={"cost_related": {"a": 1}},
points_to_evaluate=[
{"b": .99, "cost_related": {"a": 3}},
{"b": .99, "cost_related": {"a": 2}},
{"cost_related": {"a": 8}}
{"b": 0.99, "cost_related": {"a": 3}},
{"b": 0.99, "cost_related": {"a": 2}},
{"cost_related": {"a": 8}},
],
metric_constraints=[("ab", "<=", 4)]),
local_dir='logs/',
metric_constraints=[("ab", "<=", 4)],
),
local_dir="logs/",
num_samples=-1,
time_budget_s=1)
time_budget_s=1,
)
best_trial = analysis.get_best_trial()
logger.info(f"BlendSearch exp best config: {best_trial.config}")
logger.info(f"BlendSearch exp best result: {best_trial.last_result}")
points_to_evaluate = [
{"b": .99, "cost_related": {"a": 3}},
{"b": .99, "cost_related": {"a": 2}},
{"b": 0.99, "cost_related": {"a": 3}},
{"b": 0.99, "cost_related": {"a": 2}},
]
analysis = tune.run(
simple_func,
config=search_space,
low_cost_partial_config={
"cost_related": {"a": 1}
},
low_cost_partial_config={"cost_related": {"a": 1}},
points_to_evaluate=points_to_evaluate,
evaluated_rewards=[
(config["cost_related"]["a"] - 4)**2
+ (config["b"] - config["cost_related"]["a"])**2
(config["cost_related"]["a"] - 4) ** 2
+ (config["b"] - config["cost_related"]["a"]) ** 2
for config in points_to_evaluate
],
metric="obj",
mode="min",
metric_constraints=[("ab", "<=", 4)],
local_dir='logs/',
local_dir="logs/",
num_samples=-1,
time_budget_s=1)
time_budget_s=1,
)
best_trial = analysis.get_best_trial()
logger.info(f"BlendSearch best config: {best_trial.config}")
@@ -256,31 +283,33 @@ def test_run_training_function_return_value():
# Test dict return value
def evaluate_config_dict(config):
metric = (round(config['x']) - 85000)**2 - config['x'] / config['y']
metric = (round(config["x"]) - 85000) ** 2 - config["x"] / config["y"]
return {"metric": metric}
tune.run(
evaluate_config_dict,
config={
'x': tune.qloguniform(lower=1, upper=100000, q=1),
'y': tune.qrandint(lower=2, upper=100000, q=2)
"x": tune.qloguniform(lower=1, upper=100000, q=1),
"y": tune.qrandint(lower=2, upper=100000, q=2),
},
metric='metric', mode='max',
metric="metric",
mode="max",
num_samples=100,
)
# Test scalar return value
def evaluate_config_scalar(config):
metric = (round(config['x']) - 85000)**2 - config['x'] / config['y']
metric = (round(config["x"]) - 85000) ** 2 - config["x"] / config["y"]
return metric
tune.run(
evaluate_config_scalar,
config={
'x': tune.qloguniform(lower=1, upper=100000, q=1),
'y': tune.qlograndint(lower=2, upper=100000, q=2)
"x": tune.qloguniform(lower=1, upper=100000, q=1),
"y": tune.qlograndint(lower=2, upper=100000, q=2),
},
num_samples=100, mode='max',
num_samples=100,
mode="max",
)
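Both return styles exercised above are accepted: a dict of named metrics (selected via `metric=`) or a bare scalar treated as the single objective. The contrast in miniature:

```python
def returns_dict(config):
    return {"metric": config["x"] ** 2}  # addressed as metric="metric"

def returns_scalar(config):
    return config["x"] ** 2  # treated as the objective directly
```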
@@ -289,47 +318,47 @@ def test_xgboost_bs():
def _test_xgboost_cfo():
_test_xgboost('CFO')
_test_xgboost("CFO")
def test_xgboost_cfocat():
_test_xgboost('CFOCat')
_test_xgboost("CFOCat")
def _test_xgboost_dragonfly():
_test_xgboost('Dragonfly')
_test_xgboost("Dragonfly")
def _test_xgboost_skopt():
_test_xgboost('SkOpt')
_test_xgboost("SkOpt")
def _test_xgboost_nevergrad():
_test_xgboost('Nevergrad')
_test_xgboost("Nevergrad")
def _test_xgboost_zoopt():
_test_xgboost('ZOOpt')
_test_xgboost("ZOOpt")
def _test_xgboost_ax():
_test_xgboost('Ax')
_test_xgboost("Ax")
def __test_xgboost_hyperopt():
_test_xgboost('HyperOpt')
_test_xgboost("HyperOpt")
def _test_xgboost_optuna():
_test_xgboost('Optuna')
_test_xgboost("Optuna")
def _test_xgboost_asha():
_test_xgboost('ASHA')
_test_xgboost("ASHA")
def _test_xgboost_bohb():
_test_xgboost('BOHB')
_test_xgboost("BOHB")
if __name__ == "__main__":