enable ensemble when using ray (#583)

* enable ensemble when using ray

* sanitize config
This commit is contained in:
Chi Wang 2022-06-10 21:28:47 -07:00 committed by GitHub
parent 0642b6e7bb
commit f8cc38bc16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 9 deletions

View File

@ -385,6 +385,15 @@ class AutoMLState:
tune.report(**result)
return result
def sanitize(self, config: dict) -> dict:
"""Make a config ready for passing to estimator."""
config = config.get("ml", config).copy()
if "FLAML_sample_size" in config:
del config["FLAML_sample_size"]
if "learner" in config:
del config["learner"]
return config
def _train_with_config(
self,
estimator,
@ -395,11 +404,7 @@ class AutoMLState:
sample_size = config_w_resource.get(
"FLAML_sample_size", len(self.y_train_all)
)
config = config_w_resource.get("ml", config_w_resource).copy()
if "FLAML_sample_size" in config:
del config["FLAML_sample_size"]
if "learner" in config:
del config["learner"]
config = self.sanitize(config_w_resource)
this_estimator_kwargs = self.fit_kwargs_by_estimator.get(
estimator
@ -3203,7 +3208,7 @@ class AutoML(BaseEstimator):
x[1].learner_class(
task=self._state.task,
n_jobs=self._state.n_jobs,
**x[1].best_config,
**self._state.sanitize(x[1].best_config),
),
)
for x in search_states[:2]
@ -3214,13 +3219,15 @@ class AutoML(BaseEstimator):
x[1].learner_class(
task=self._state.task,
n_jobs=self._state.n_jobs,
**x[1].best_config,
**self._state.sanitize(x[1].best_config),
),
)
for x in search_states[2:]
if x[1].best_loss < 4 * self._selected.best_loss
]
logger.info(estimators)
logger.info(
[(estimator[0], estimator[1].params) for estimator in estimators]
)
if len(estimators) > 1:
if self._state.task in CLASSIFICATION:
from sklearn.ensemble import StackingClassifier as Stacker

View File

@ -1 +1 @@
__version__ = "1.0.6"
__version__ = "1.0.7"

View File

@ -256,6 +256,7 @@ class TestClassification(unittest.TestCase):
time_budget=10,
task="classification",
n_concurrent_trials=2,
ensemble=True,
)
except ImportError:
return