check n_iter == 1

This commit is contained in:
Chi Wang 2021-10-16 03:09:56 -07:00
parent 46b29e05c7
commit b2d8b097d7
2 changed files with 4 additions and 3 deletions

View File

@ -325,7 +325,7 @@ class LGBMEstimator(BaseEstimator):
):
self.params["n_estimators"] = 1
self._t1 = self._fit(X_train, y_train, **kwargs)
if self._t1 >= budget:
if self._t1 >= budget or n_iter == 1:
# self.params["n_estimators"] = n_iter
return self._t1
self.params["n_estimators"] = min(n_iter, 4)
@ -742,7 +742,7 @@ class CatBoostEstimator(BaseEstimator):
X_train, y_train, cat_features=cat_features, **kwargs
)
CatBoostEstimator._t1 = time.time() - start_time
if CatBoostEstimator._t1 >= budget:
if CatBoostEstimator._t1 >= budget or n_iter == 1:
# self.params["n_estimators"] = n_iter
self._model = CatBoostEstimator._smallmodel
shutil.rmtree(train_dir, ignore_errors=True)

View File

@ -37,10 +37,11 @@ class TestTrainingLog(unittest.TestCase):
if automl.best_estimator:
estimator, config = automl.best_estimator, automl.best_config
model0 = automl.best_model_for_estimator(estimator)
print(model0.estimator)
print(model0.params["n_estimators"], model0.estimator)
automl.time_budget = None
model, _ = automl._state._train_with_config(estimator, config)
print(model.estimator)
# model0 and model are equivalent unless model0's n_estimator is out of search space range
assert (
str(model0.estimator) == str(model.estimator)