no search when max_iter < 2

This commit is contained in:
Chi Wang 2021-10-16 01:11:12 -07:00
parent 5fb843234f
commit b03a87e737
4 changed files with 53 additions and 15 deletions

View File

@ -104,6 +104,7 @@ class SearchState:
self.trained_estimator = None self.trained_estimator = None
self.sample_size = None self.sample_size = None
self.trial_time = 0 self.trial_time = 0
self.best_n_iter = None
def update(self, result, time_used, save_model_history=False): def update(self, result, time_used, save_model_history=False):
if result: if result:
@ -430,7 +431,7 @@ class AutoML:
@property @property
def time_to_find_best_model(self) -> float: def time_to_find_best_model(self) -> float:
"""time taken to find best model in seconds""" """Time taken to find best model in seconds"""
return self.__dict__.get("_time_taken_best_iter") return self.__dict__.get("_time_taken_best_iter")
def predict(self, X_test): def predict(self, X_test):
@ -1768,6 +1769,17 @@ class AutoML:
better = True # whether we find a better model in one trial better = True # whether we find a better model in one trial
if self._ensemble: if self._ensemble:
self.best_model = {} self.best_model = {}
if self._max_iter < 2 and self.estimator_list:
# when max_iter is 1, no need to search
self._max_iter = 0
self._best_estimator = estimator = self.estimator_list[0]
self._selected = state = self._search_states[estimator]
state.best_config_sample_size = self._state.data_size
state.best_config = (
state.init_config
if isinstance(state.init_config, dict)
else state.init_config[0]
)
for self._track_iter in range(self._max_iter): for self._track_iter in range(self._max_iter):
if self._estimator_index is None: if self._estimator_index is None:
estimator = self._active_estimators[0] estimator = self._active_estimators[0]
@ -1844,9 +1856,9 @@ class AutoML:
metric="val_loss", metric="val_loss",
mode="min", mode="min",
space=search_space, space=search_space,
points_to_evaluate=points_to_evaluate points_to_evaluate=[
if len(search_state.init_config) == len(search_space) p for p in points_to_evaluate if len(p) == len(search_space)
else None, ],
) )
search_state.search_alg = ConcurrencyLimiter(algo, max_concurrent=1) search_state.search_alg = ConcurrencyLimiter(algo, max_concurrent=1)
# search_state.search_alg = algo # search_state.search_alg = algo

View File

@ -25,17 +25,43 @@ class TestTrainingLog(unittest.TestCase):
"mem_thres": 1024 * 1024, "mem_thres": 1024 * 1024,
"n_jobs": 1, "n_jobs": 1,
"model_history": True, "model_history": True,
"train_time_limit": 0.01, "train_time_limit": 0.1,
"verbose": 3, "verbose": 3,
"ensemble": True, "ensemble": True,
"keep_search_state": True, "keep_search_state": True,
} }
X_train, y_train = fetch_california_housing(return_X_y=True) X_train, y_train = fetch_california_housing(return_X_y=True)
automl.fit(X_train=X_train, y_train=y_train, **automl_settings) automl.fit(X_train=X_train, y_train=y_train, **automl_settings)
automl._state._train_with_config(automl.best_estimator, automl.best_config)
# Check if the training log file is populated. # Check if the training log file is populated.
self.assertTrue(os.path.exists(filename)) self.assertTrue(os.path.exists(filename))
if automl.best_estimator:
estimator, config = automl.best_estimator, automl.best_config
model0 = automl.best_model_for_estimator(estimator)
print(model0.estimator)
automl.time_budget = None
model, _ = automl._state._train_with_config(estimator, config)
# model0 and model are equivalent unless model0's n_estimator is out of search space range
assert (
str(model0.estimator) == str(model.estimator)
or model0["n_estimators"] < 4
)
# assuming estimator & config are saved and loaded as follows
automl = AutoML()
automl.fit(
X_train=X_train,
y_train=y_train,
max_iter=0,
task="regression",
estimator_list=[estimator],
n_jobs=1,
starting_points={estimator: config},
)
# then the fitted model should be equivalent to model
# print(str(model.estimator), str(automl.model.estimator))
assert str(model.estimator) == str(automl.model.estimator)
with training_log_reader(filename) as reader: with training_log_reader(filename) as reader:
count = 0 count = 0
for record in reader.records(): for record in reader.records():