mirror of https://github.com/microsoft/autogen.git
enable ensemble when using ray (#583)
* enable ensemble when using ray * sanitize config
This commit is contained in:
parent
0642b6e7bb
commit
f8cc38bc16
|
@ -385,6 +385,15 @@ class AutoMLState:
|
|||
tune.report(**result)
|
||||
return result
|
||||
|
||||
def sanitize(self, config: dict) -> dict:
|
||||
"""Make a config ready for passing to estimator."""
|
||||
config = config.get("ml", config).copy()
|
||||
if "FLAML_sample_size" in config:
|
||||
del config["FLAML_sample_size"]
|
||||
if "learner" in config:
|
||||
del config["learner"]
|
||||
return config
|
||||
|
||||
def _train_with_config(
|
||||
self,
|
||||
estimator,
|
||||
|
@ -395,11 +404,7 @@ class AutoMLState:
|
|||
sample_size = config_w_resource.get(
|
||||
"FLAML_sample_size", len(self.y_train_all)
|
||||
)
|
||||
config = config_w_resource.get("ml", config_w_resource).copy()
|
||||
if "FLAML_sample_size" in config:
|
||||
del config["FLAML_sample_size"]
|
||||
if "learner" in config:
|
||||
del config["learner"]
|
||||
config = self.sanitize(config_w_resource)
|
||||
|
||||
this_estimator_kwargs = self.fit_kwargs_by_estimator.get(
|
||||
estimator
|
||||
|
@ -3203,7 +3208,7 @@ class AutoML(BaseEstimator):
|
|||
x[1].learner_class(
|
||||
task=self._state.task,
|
||||
n_jobs=self._state.n_jobs,
|
||||
**x[1].best_config,
|
||||
**self._state.sanitize(x[1].best_config),
|
||||
),
|
||||
)
|
||||
for x in search_states[:2]
|
||||
|
@ -3214,13 +3219,15 @@ class AutoML(BaseEstimator):
|
|||
x[1].learner_class(
|
||||
task=self._state.task,
|
||||
n_jobs=self._state.n_jobs,
|
||||
**x[1].best_config,
|
||||
**self._state.sanitize(x[1].best_config),
|
||||
),
|
||||
)
|
||||
for x in search_states[2:]
|
||||
if x[1].best_loss < 4 * self._selected.best_loss
|
||||
]
|
||||
logger.info(estimators)
|
||||
logger.info(
|
||||
[(estimator[0], estimator[1].params) for estimator in estimators]
|
||||
)
|
||||
if len(estimators) > 1:
|
||||
if self._state.task in CLASSIFICATION:
|
||||
from sklearn.ensemble import StackingClassifier as Stacker
|
||||
|
|
|
@ -1 +1 @@
|
|||
__version__ = "1.0.6"
|
||||
__version__ = "1.0.7"
|
||||
|
|
|
@ -256,6 +256,7 @@ class TestClassification(unittest.TestCase):
|
|||
time_budget=10,
|
||||
task="classification",
|
||||
n_concurrent_trials=2,
|
||||
ensemble=True,
|
||||
)
|
||||
except ImportError:
|
||||
return
|
||||
|
|
Loading…
Reference in New Issue