mirror of https://github.com/microsoft/autogen.git
stratified group kfold splitter (#899)
* stratified group kfold splitter
* exclude catboost

---------

Co-authored-by: Shaokun <shaokunzhang529@gmail.com>
Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
This commit is contained in:
parent
cb3378d621
commit
fbea1d06dd
|
@ -17,6 +17,7 @@ from sklearn.model_selection import (
|
||||||
GroupKFold,
|
GroupKFold,
|
||||||
TimeSeriesSplit,
|
TimeSeriesSplit,
|
||||||
GroupShuffleSplit,
|
GroupShuffleSplit,
|
||||||
|
StratifiedGroupKFold,
|
||||||
)
|
)
|
||||||
from sklearn.utils import shuffle
|
from sklearn.utils import shuffle
|
||||||
from sklearn.base import BaseEstimator
|
from sklearn.base import BaseEstimator
|
||||||
|
@ -1575,8 +1576,8 @@ class AutoML(BaseEstimator):
|
||||||
else:
|
else:
|
||||||
# logger.info("Using splitter object")
|
# logger.info("Using splitter object")
|
||||||
self._state.kf = self._split_type
|
self._state.kf = self._split_type
|
||||||
if isinstance(self._state.kf, GroupKFold):
|
if isinstance(self._state.kf, (GroupKFold, StratifiedGroupKFold)):
|
||||||
# self._split_type is either "group" or a GroupKFold object
|
# self._split_type is either "group", a GroupKFold object, or a StratifiedGroupKFold object
|
||||||
self._state.kf.groups = self._state.groups_all
|
self._state.kf.groups = self._state.groups_all
|
||||||
|
|
||||||
def add_learner(self, learner_name, learner_class):
|
def add_learner(self, learner_name, learner_class):
|
||||||
|
|
|
@ -17,7 +17,12 @@ from sklearn.metrics import (
|
||||||
mean_absolute_percentage_error,
|
mean_absolute_percentage_error,
|
||||||
ndcg_score,
|
ndcg_score,
|
||||||
)
|
)
|
||||||
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
|
from sklearn.model_selection import (
|
||||||
|
RepeatedStratifiedKFold,
|
||||||
|
GroupKFold,
|
||||||
|
TimeSeriesSplit,
|
||||||
|
StratifiedGroupKFold,
|
||||||
|
)
|
||||||
from flaml.automl.model import (
|
from flaml.automl.model import (
|
||||||
XGBoostSklearnEstimator,
|
XGBoostSklearnEstimator,
|
||||||
XGBoost_TS,
|
XGBoost_TS,
|
||||||
|
@ -517,7 +522,7 @@ def evaluate_model_CV(
|
||||||
shuffle = getattr(kf, "shuffle", task not in TS_FORECAST)
|
shuffle = getattr(kf, "shuffle", task not in TS_FORECAST)
|
||||||
if isinstance(kf, RepeatedStratifiedKFold):
|
if isinstance(kf, RepeatedStratifiedKFold):
|
||||||
kf = kf.split(X_train_split, y_train_split)
|
kf = kf.split(X_train_split, y_train_split)
|
||||||
elif isinstance(kf, GroupKFold):
|
elif isinstance(kf, (GroupKFold, StratifiedGroupKFold)):
|
||||||
groups = kf.groups
|
groups = kf.groups
|
||||||
kf = kf.split(X_train_split, y_train_split, groups)
|
kf = kf.split(X_train_split, y_train_split, groups)
|
||||||
shuffle = False
|
shuffle = False
|
||||||
|
@ -548,8 +553,16 @@ def evaluate_model_CV(
|
||||||
weight[val_index],
|
weight[val_index],
|
||||||
)
|
)
|
||||||
if groups is not None:
|
if groups is not None:
|
||||||
fit_kwargs["groups"] = groups[train_index]
|
fit_kwargs["groups"] = (
|
||||||
groups_val = groups[val_index]
|
groups[train_index]
|
||||||
|
if isinstance(groups, np.ndarray)
|
||||||
|
else groups.iloc[train_index]
|
||||||
|
)
|
||||||
|
groups_val = (
|
||||||
|
groups[val_index]
|
||||||
|
if isinstance(groups, np.ndarray)
|
||||||
|
else groups.iloc[val_index]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
groups_val = None
|
groups_val = None
|
||||||
val_loss_i, metric_i, train_time_i, pred_time_i = get_val_loss(
|
val_loss_i, metric_i, train_time_i, pred_time_i = get_val_loss(
|
||||||
|
|
|
@ -94,6 +94,33 @@ def test_groups():
|
||||||
automl.fit(X, y, **automl_settings)
|
automl.fit(X, y, **automl_settings)
|
||||||
|
|
||||||
|
|
||||||
|
def test_stratified_groupkfold():
|
||||||
|
from sklearn.model_selection import StratifiedGroupKFold
|
||||||
|
from flaml.data import load_openml_dataset
|
||||||
|
|
||||||
|
X_train, _, y_train, _ = load_openml_dataset(dataset_id=1169, data_dir="test/")
|
||||||
|
splitter = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0)
|
||||||
|
|
||||||
|
automl = AutoML()
|
||||||
|
settings = {
|
||||||
|
"time_budget": 6,
|
||||||
|
"metric": "ap",
|
||||||
|
"eval_method": "cv",
|
||||||
|
"split_type": splitter,
|
||||||
|
"groups": X_train["Airline"],
|
||||||
|
"estimator_list": [
|
||||||
|
"lgbm",
|
||||||
|
"rf",
|
||||||
|
"xgboost",
|
||||||
|
"extra_tree",
|
||||||
|
"xgb_limitdepth",
|
||||||
|
"lrl1",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
automl.fit(X_train=X_train, y_train=y_train, **settings)
|
||||||
|
|
||||||
|
|
||||||
def test_rank():
|
def test_rank():
|
||||||
from sklearn.externals._arff import ArffException
|
from sklearn.externals._arff import ArffException
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue