use zeroshot when no budget is given; custom_hp (#563)

* use zeroshot when no budget is given; custom_hp

* update Getting-Started

* protobuf version

* X_val
This commit is contained in:
Chi Wang 2022-05-28 17:22:09 -07:00 committed by GitHub
parent 7748e0ff49
commit 49e8f7f028
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 201 additions and 99 deletions

View File

@ -164,6 +164,9 @@ class SearchState:
assert (
"domain" in space
), f"{name}'s domain is missing in the search space spec {space}"
if space["domain"] is None:
# don't search this hp
continue
self._search_space_domain[name] = space["domain"]
if "low_cost_init_value" in space:
@ -527,8 +530,8 @@ class AutoML(BaseEstimator):
Use all available resources when n_jobs == -1.
log_file_name: A string of the log file name | default="". To disable logging,
set it to be an empty string "".
estimator_list: A list of strings for estimator names, or 'auto'
e.g., ```['lgbm', 'xgboost', 'xgb_limitdepth', 'catboost', 'rf', 'extra_tree']```
estimator_list: A list of strings for estimator names, or 'auto'.
e.g., ```['lgbm', 'xgboost', 'xgb_limitdepth', 'catboost', 'rf', 'extra_tree']```.
time_budget: A float number of the time budget in seconds.
Use -1 if no time limit.
max_iter: An integer of the maximal number of iterations.
@ -676,7 +679,8 @@ class AutoML(BaseEstimator):
self._state = AutoMLState()
self._state.learner_classes = {}
self._settings = settings
settings["time_budget"] = settings.get("time_budget", 60)
# no budget by default
settings["time_budget"] = settings.get("time_budget", -1)
settings["task"] = settings.get("task", "classification")
settings["n_jobs"] = settings.get("n_jobs", -1)
settings["eval_method"] = settings.get("eval_method", "auto")
@ -686,7 +690,7 @@ class AutoML(BaseEstimator):
settings["metric"] = settings.get("metric", "auto")
settings["estimator_list"] = settings.get("estimator_list", "auto")
settings["log_file_name"] = settings.get("log_file_name", "")
settings["max_iter"] = settings.get("max_iter", 1000000)
settings["max_iter"] = settings.get("max_iter") # no budget by default
settings["sample"] = settings.get("sample", True)
settings["ensemble"] = settings.get("ensemble", False)
settings["log_type"] = settings.get("log_type", "better")
@ -2061,17 +2065,18 @@ class AutoML(BaseEstimator):
task: A string of the task type, e.g.,
'classification', 'regression', 'ts_forecast_regression',
'ts_forecast_classification', 'rank', 'seq-classification',
'seq-regression', 'summarization'
'seq-regression', 'summarization'.
n_jobs: An integer of the number of threads for training | default=-1.
Use all available resources when n_jobs == -1.
log_file_name: A string of the log file name | default="". To disable logging,
set it to be an empty string "".
estimator_list: A list of strings for estimator names, or 'auto'
e.g., ```['lgbm', 'xgboost', 'xgb_limitdepth', 'catboost', 'rf', 'extra_tree']```
estimator_list: A list of strings for estimator names, or 'auto'.
e.g., ```['lgbm', 'xgboost', 'xgb_limitdepth', 'catboost', 'rf', 'extra_tree']```.
time_budget: A float number of the time budget in seconds.
Use -1 if no time limit.
max_iter: An integer of the maximal number of iterations.
NOTE: when both time_budget and max_iter are unspecified,
only one model will be trained per estimator.
sample: A boolean of whether to sample the training data during
search.
ensemble: boolean or dict | default=False. Whether to perform
@ -2252,7 +2257,9 @@ class AutoML(BaseEstimator):
else log_file_name
)
max_iter = self._settings.get("max_iter") if max_iter is None else max_iter
sample = self._settings.get("sample") if sample is None else sample
sample_is_none = sample is None
if sample_is_none:
sample = self._settings.get("sample")
ensemble = self._settings.get("ensemble") if ensemble is None else ensemble
log_type = log_type or self._settings.get("log_type")
model_history = (
@ -2280,11 +2287,9 @@ class AutoML(BaseEstimator):
split_type = split_type or self._settings.get("split_type")
hpo_method = hpo_method or self._settings.get("hpo_method")
learner_selector = learner_selector or self._settings.get("learner_selector")
starting_points = (
self._settings.get("starting_points")
if starting_points is None
else starting_points
)
no_starting_points = starting_points is None
if no_starting_points:
starting_points = self._settings.get("starting_points")
n_concurrent_trials = n_concurrent_trials or self._settings.get(
"n_concurrent_trials"
)
@ -2296,6 +2301,8 @@ class AutoML(BaseEstimator):
early_stop = (
self._settings.get("early_stop") if early_stop is None else early_stop
)
# no search budget is provided?
no_budget = time_budget == -1 and max_iter is None and not early_stop
append_log = (
self._settings.get("append_log") if append_log is None else append_log
)
@ -2374,14 +2381,6 @@ class AutoML(BaseEstimator):
self._retrain_in_budget = retrain_full == "budget" and (
eval_method == "holdout" and self._state.X_val is None
)
self._state.retrain_final = (
retrain_full is True
and eval_method == "holdout"
and (self._state.X_val is None or self._use_ray is not False)
or eval_method == "cv"
and (max_iter > 0 or retrain_full is True)
or max_iter == 1
)
self._auto_augment = auto_augment
self._min_sample_size = min_sample_size
self._prepare_data(eval_method, split_ratio, n_splits)
@ -2486,7 +2485,32 @@ class AutoML(BaseEstimator):
estimator_list += ["arima", "sarimax"]
elif "regression" != self._state.task:
estimator_list += ["lrl1"]
# When no search budget is specified
if no_budget:
max_iter = len(estimator_list)
self._learner_selector = "roundrobin"
if sample_is_none:
self._sample = False
if no_starting_points:
starting_points = "data"
logger.warning(
"No search budget is provided via time_budget or max_iter."
" Training only one model per estimator."
" To tune hyperparameters for each estimator,"
" please provide budget either via time_budget or max_iter."
)
elif max_iter is None:
# set to a large number
max_iter = 1000000
self._state.retrain_final = (
retrain_full is True
and eval_method == "holdout"
and (X_val is None or self._use_ray is not False)
or eval_method == "cv"
and (max_iter > 0 or retrain_full is True)
or max_iter == 1
)
# add custom learner
for estimator_name in estimator_list:
if estimator_name not in self._state.learner_classes:
self.add_learner(

View File

@ -25,9 +25,14 @@ def meta_feature(task, X_train, y_train, meta_feature_names):
elif each_feature_name == "NumberOfClasses":
this_feature.append(len(np.unique(y_train)) if is_classification else 0)
elif each_feature_name == "PercentageOfNumericFeatures":
this_feature.append(
X_train.select_dtypes(include=np.number).shape[1] / n_feat
)
try:
# this is feature is only supported for dataframe
this_feature.append(
X_train.select_dtypes(include=np.number).shape[1] / n_feat
)
except AttributeError:
# 'numpy.ndarray' object has no attribute 'select_dtypes'
this_feature.append(1) # all features are numeric
else:
raise ValueError("Feature {} not implemented. ".format(each_feature_name))

View File

@ -547,7 +547,6 @@ class TransformersEstimator(BaseEstimator):
add_prefix_space=True
if "roberta" in self._training_args.model_path
else False, # If roberta model, must set add_prefix_space to True to avoid the assertion error at
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/roberta/tokenization_roberta_fast.py#L249
)
@ -956,10 +955,6 @@ class LGBMEstimator(BaseEstimator):
"domain": tune.loguniform(lower=1 / 1024, upper=1.0),
"init_value": 0.1,
},
# 'subsample': {
# 'domain': tune.uniform(lower=0.1, upper=1.0),
# 'init_value': 1.0,
# },
"log_max_bin": { # log transformed with base 2
"domain": tune.lograndint(lower=3, upper=11),
"init_value": 8,

View File

@ -62,6 +62,7 @@ setuptools.setup(
"rouge_score",
"hcrystalball==0.1.10",
"seqeval",
"protobuf<4", # to prevent TypeError in ray
],
"catboost": ["catboost>=0.26"],
"blendsearch": ["optuna==2.8.0"],

View File

@ -0,0 +1,66 @@
import sys
import pytest
from flaml import AutoML, tune
@pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
def test_custom_hp_nlp():
from test.nlp.utils import get_toy_data_seqclassification, get_automl_settings
X_train, y_train, X_val, y_val, X_test = get_toy_data_seqclassification()
automl = AutoML()
automl_settings = get_automl_settings()
automl_settings["custom_hp"] = None
automl_settings["custom_hp"] = {
"transformer": {
"model_path": {
"domain": tune.choice(["google/electra-small-discriminator"]),
},
"num_train_epochs": {"domain": 3},
}
}
automl_settings["fit_kwargs_by_estimator"] = {
"transformer": {
"output_dir": "test/data/output/",
"ckpt_per_epoch": 1,
"fp16": False,
}
}
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
def test_custom_hp():
from sklearn.datasets import load_iris
X_train, y_train = load_iris(return_X_y=True)
automl = AutoML()
custom_hp = {
"xgboost": {
"n_estimators": {
"domain": tune.lograndint(lower=1, upper=100),
"low_cost_init_value": 1,
},
},
"rf": {
"max_leaves": {
"domain": None, # disable search
},
},
"lgbm": {
"subsample": {
"domain": tune.uniform(lower=0.1, upper=1.0),
"init_value": 1.0,
},
"subsample_freq": {
"domain": 1, # subsample_freq must > 0 to enable subsample
},
},
}
automl.fit(X_train, y_train, custom_hp=custom_hp, time_budget=2)
print(automl.best_config_per_estimator)
if __name__ == "__main__":
test_custom_hp()

View File

@ -167,9 +167,8 @@ def test_multivariate_forecast_num(budget=5):
split_idx = num_samples - time_horizon
train_df = df[:split_idx]
test_df = df[split_idx:]
X_test = test_df[
["timeStamp", "temp", "precip"]
] # test dataframe must contain values for the regressors / multivariate variables
# test dataframe must contain values for the regressors / multivariate variables
X_test = test_df[["timeStamp", "temp", "precip"]]
y_test = test_df["demand"]
# return
automl = AutoML()

View File

@ -48,6 +48,7 @@ def test_automl(budget=5, dataset_format="dataframe", hpo_method=None):
"Training duration of best run: {0:.4g} s".format(automl.best_config_train_time)
)
print(automl.model.estimator)
print(automl.best_config_per_estimator)
print("time taken to find best model:", automl.time_to_find_best_model)
""" pickle and save the automl object """
import pickle
@ -92,6 +93,11 @@ def test_automl_array():
test_automl(5, "array", "bs")
def _test_nobudget():
# needs large RAM to run this test
test_automl(-1)
def test_mlflow():
import subprocess
import sys

View File

@ -8,7 +8,7 @@ from flaml import tune
class TestWarmStart(unittest.TestCase):
def test_fit_w_freezinghp_starting_point(self, as_frame=True):
automl_experiment = AutoML()
automl = AutoML()
automl_settings = {
"time_budget": 1,
"metric": "accuracy",
@ -24,20 +24,20 @@ class TestWarmStart(unittest.TestCase):
# test drop column
X_train.columns = range(X_train.shape[1])
X_train[X_train.shape[1]] = np.zeros(len(y_train))
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
automl_val_accuracy = 1.0 - automl_experiment.best_loss
print("Best ML leaner:", automl_experiment.best_estimator)
print("Best hyperparmeter config:", automl_experiment.best_config)
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
automl_val_accuracy = 1.0 - automl.best_loss
print("Best ML leaner:", automl.best_estimator)
print("Best hyperparmeter config:", automl.best_config)
print("Best accuracy on validation data: {0:.4g}".format(automl_val_accuracy))
print(
"Training duration of best run: {0:.4g} s".format(
automl_experiment.best_config_train_time
automl.best_config_train_time
)
)
# 1. Get starting points from previous experiments.
starting_points = automl_experiment.best_config_per_estimator
starting_points = automl.best_config_per_estimator
print("starting_points", starting_points)
print("loss of the starting_points", automl_experiment.best_loss_per_estimator)
print("loss of the starting_points", automl.best_loss_per_estimator)
starting_point = starting_points["lgbm"]
hps_to_freeze = ["colsample_bytree", "reg_alpha", "reg_lambda", "log_max_bin"]
@ -85,8 +85,8 @@ class TestWarmStart(unittest.TestCase):
return space
new_estimator_name = "large_lgbm"
new_automl_experiment = AutoML()
new_automl_experiment.add_learner(
new_automl = AutoML()
new_automl.add_learner(
learner_name=new_estimator_name, learner_class=MyPartiallyFreezedLargeLGBM
)
@ -103,22 +103,26 @@ class TestWarmStart(unittest.TestCase):
"starting_points": {new_estimator_name: starting_point},
}
new_automl_experiment.fit(
X_train=X_train, y_train=y_train, **automl_settings_resume
)
new_automl.fit(X_train=X_train, y_train=y_train, **automl_settings_resume)
new_automl_val_accuracy = 1.0 - new_automl_experiment.best_loss
print("Best ML leaner:", new_automl_experiment.best_estimator)
print("Best hyperparmeter config:", new_automl_experiment.best_config)
new_automl_val_accuracy = 1.0 - new_automl.best_loss
print("Best ML leaner:", new_automl.best_estimator)
print("Best hyperparmeter config:", new_automl.best_config)
print(
"Best accuracy on validation data: {0:.4g}".format(new_automl_val_accuracy)
)
print(
"Training duration of best run: {0:.4g} s".format(
new_automl_experiment.best_config_train_time
new_automl.best_config_train_time
)
)
def test_nobudget(self):
automl = AutoML()
X_train, y_train = load_iris(return_X_y=True)
automl.fit(X_train, y_train)
print(automl.best_config_per_estimator)
if __name__ == "__main__":
unittest.main()

View File

@ -1,36 +0,0 @@
import sys
import pytest
from utils import get_toy_data_seqclassification, get_automl_settings
@pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
def test_custom_hp_nlp():
from flaml import AutoML
import flaml
X_train, y_train, X_val, y_val, X_test = get_toy_data_seqclassification()
automl = AutoML()
automl_settings = get_automl_settings()
automl_settings["custom_hp"] = None
automl_settings["custom_hp"] = {
"transformer": {
"model_path": {
"domain": flaml.tune.choice(["google/electra-small-discriminator"]),
},
"num_train_epochs": {"domain": 3},
}
}
automl_settings["fit_kwargs_by_estimator"] = {
"transformer": {
"output_dir": "test/data/output/",
"ckpt_per_epoch": 1,
"fp16": False,
}
}
automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
if __name__ == "__main__":
test_custom_hp_nlp()

View File

@ -28,14 +28,14 @@ For example, with three lines of code, you can start using this economical and f
```python
from flaml import AutoML
automl = AutoML()
automl.fit(X_train, y_train, task="classification")
automl.fit(X_train, y_train, task="classification", time_budget=60)
```
It automatically tunes the hyperparameters and selects the best model from default learners such as LightGBM, XGBoost, random forest etc. [Customizing](Use-Cases/task-oriented-automl#customize-automlfit) the optimization metrics, learners and search spaces etc. is very easy. For example,
It automatically tunes the hyperparameters and selects the best model from default learners such as LightGBM, XGBoost, random forest etc. for the specified time budget 60 seconds. [Customizing](Use-Cases/task-oriented-automl#customize-automlfit) the optimization metrics, learners and search spaces etc. is very easy. For example,
```python
automl.add_learner("mylgbm", MyLGBMEstimator)
automl.fit(X_train, y_train, task="classification", metric=custom_metric, estimator_list=["mylgbm"])
automl.fit(X_train, y_train, task="classification", metric=custom_metric, estimator_list=["mylgbm"], time_budget=60)
```
#### [Tune user-defined function](Use-Cases/Tune-User-Defined-Function)
@ -88,7 +88,7 @@ Then, you can use it just like you use the original `LGMBClassifier`. Your other
### Where to Go Next?
* Understand the use cases for [Task-oriented AutoML](Use-Cases/task-oriented-automl) and [Tune user-defined function](Use-Cases/Tune-User-Defined-Function).
* Understand the use cases for [Task-oriented AutoML](Use-Cases/task-oriented-automl), [Tune user-defined function](Use-Cases/Tune-User-Defined-Function) and [Zero-shot AutoML](Use-Cases/Zero-Shot-AutoML).
* Find code examples under "Examples": from [AutoML - Classification](Examples/AutoML-Classification) to [Tune - PyTorch](Examples/Tune-PyTorch).
* Watch [video tutorials](https://www.youtube.com/channel/UCfU0zfFXHXdAd5x-WvFBk5A).
* Learn about [research](Research) around FLAML.

View File

@ -19,7 +19,7 @@
- 'token-classification': token classification.
- 'multichoice-classification': multichoice classification.
An optional input is `time_budget` for searching models and hyperparameters. When not specified, a default budget of 60 seconds will be used.
Two optional inputs are `time_budget` and `max_iter` for searching models and hyperparameters. When both are unspecified, only one model per estimator will be trained (using our [zero-shot](Zero-Shot-AutoML) technique).
A typical way to use `flaml.AutoML`:
@ -39,7 +39,7 @@ with open("automl.pkl", "rb") as f:
pred = automl.predict(X_test)
```
If users provide the minimal inputs only, `AutoML` uses the default settings for time budget, optimization metric, estimator list etc.
If users provide the minimal inputs only, `AutoML` uses the default settings for optimization metric, estimator list etc.
## Customize AutoML.fit()
@ -191,9 +191,6 @@ Each estimator class, built-in or not, must have a `search_space` function. In t
In the example above, we tune four hyperparameters, three integers and one float. They all follow a log-uniform distribution. "max_leaf" and "n_iter" have "low_cost_init_value" specified as their values heavily influence the training cost.
To customize the search space for a built-in estimator, use a similar approach to define a class that inherits the existing estimator. For example,
```python
@ -234,17 +231,46 @@ class XGBoost2D(XGBoostSklearnEstimator):
We override the `search_space` function to tune two hyperparameters only, "n_estimators" and "max_leaves". They are both random integers in the log space, ranging from 4 to data-dependent upper bound. The lower bound for each corresponds to low training cost, hence the "low_cost_init_value" for each is set to 4.
##### A shortcut to override the search space
One can use the `custom_hp` argument in [`AutoML.fit()`](../reference/automl#fit) to override the search space for an existing estimator quickly. For example, if you would like to temporarily change the search range of "n_estimators" of xgboost, disable searching "max_leaves" in random forest, and add "subsample" in the search space of lightgbm, you can set:
```python
custom_hp = {
"xgboost": {
"n_estimators": {
"domain": tune.lograndint(lower=new_lower, upper=new_upper),
"low_cost_init_value": new_lower,
},
},
"rf": {
"max_leaves": {
"domain": None, # disable search
},
},
"lgbm": {
"subsample": {
"domain": tune.uniform(lower=0.1, upper=1.0),
"init_value": 1.0,
},
"subsample_freq": {
"domain": 1, # subsample_freq must > 0 to enable subsample
},
},
}
```
### Constraint
There are several types of constraints you can impose.
1. End-to-end constraints on the AutoML process.
1. Constraints on the AutoML process.
- `time_budget`: constrains the wall-clock time (seconds) used by the AutoML process. We provide some tips on [how to set time budget](#how-to-set-time-budget).
- `max_iter`: constrains the maximal number of models to try in the AutoML process.
2. Constraints on the (hyperparameters of) the estimators.
2. Constraints on the constructor arguments of the estimators.
Some constraints on the estimator can be implemented via the custom learner. For example,
@ -255,7 +281,18 @@ class MonotonicXGBoostEstimator(XGBoostSklearnEstimator):
return super().search_space(**args).update({"monotone_constraints": "(1, -1)"})
```
It adds a monotonicity constraint to XGBoost. This approach can be used to set any constraint that is a parameter in the underlying estimator's constructor.
It adds a monotonicity constraint to XGBoost. This approach can be used to set any constraint that is an argument in the underlying estimator's constructor.
A shortcut to do this is to use the [`custom_hp`](#a-shortcut-to-override-the-search-space) argument:
```python
custom_hp = {
"xgboost": {
"monotone_constraints": {
"domain": "(1, -1)" # fix the domain as a constant
}
}
}
```
3. Constraints on the models tried in AutoML.
@ -267,6 +304,7 @@ For example,
```python
automl.fit(X_train, y_train, max_iter=100, train_time_limit=1, pred_time_limit=1e-3)
```
4. Constraints on the metrics of the ML model tried in AutoML.
When users provide a [custom metric function](https://microsoft.github.io/FLAML/docs/Use-Cases/Task-Oriented-AutoML#optimization-metric), which returns a primary optimization metric and a dictionary of additional metrics (typically also about the model) to log, users can also specify constraints on one or more of the metrics in the dictionary of additional metrics.