mirror of https://github.com/microsoft/autogen.git
check n_iter == 1
This commit is contained in:
parent
46b29e05c7
commit
b2d8b097d7
|
@ -325,7 +325,7 @@ class LGBMEstimator(BaseEstimator):
|
|||
):
|
||||
self.params["n_estimators"] = 1
|
||||
self._t1 = self._fit(X_train, y_train, **kwargs)
|
||||
if self._t1 >= budget:
|
||||
if self._t1 >= budget or n_iter == 1:
|
||||
# self.params["n_estimators"] = n_iter
|
||||
return self._t1
|
||||
self.params["n_estimators"] = min(n_iter, 4)
|
||||
|
@ -742,7 +742,7 @@ class CatBoostEstimator(BaseEstimator):
|
|||
X_train, y_train, cat_features=cat_features, **kwargs
|
||||
)
|
||||
CatBoostEstimator._t1 = time.time() - start_time
|
||||
if CatBoostEstimator._t1 >= budget:
|
||||
if CatBoostEstimator._t1 >= budget or n_iter == 1:
|
||||
# self.params["n_estimators"] = n_iter
|
||||
self._model = CatBoostEstimator._smallmodel
|
||||
shutil.rmtree(train_dir, ignore_errors=True)
|
||||
|
|
|
@ -37,10 +37,11 @@ class TestTrainingLog(unittest.TestCase):
|
|||
if automl.best_estimator:
|
||||
estimator, config = automl.best_estimator, automl.best_config
|
||||
model0 = automl.best_model_for_estimator(estimator)
|
||||
print(model0.estimator)
|
||||
print(model0.params["n_estimators"], model0.estimator)
|
||||
|
||||
automl.time_budget = None
|
||||
model, _ = automl._state._train_with_config(estimator, config)
|
||||
print(model.estimator)
|
||||
# model0 and model are equivalent unless model0's n_estimator is out of search space range
|
||||
assert (
|
||||
str(model0.estimator) == str(model.estimator)
|
||||
|
|
Loading…
Reference in New Issue