stratified group kfold splitter (#899)

* stratified group kfold splitter

* exclude catboost

---------

Co-authored-by: Shaokun <shaokunzhang529@gmail.com>
Co-authored-by: Qingyun Wu <qingyun.wu@psu.edu>
This commit is contained in:
Chi Wang 2023-02-05 15:26:14 -08:00 committed by GitHub
parent cb3378d621
commit fbea1d06dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 6 deletions

View File

@ -17,6 +17,7 @@ from sklearn.model_selection import (
GroupKFold,
TimeSeriesSplit,
GroupShuffleSplit,
StratifiedGroupKFold,
)
from sklearn.utils import shuffle
from sklearn.base import BaseEstimator
@ -1575,8 +1576,8 @@ class AutoML(BaseEstimator):
else:
# logger.info("Using splitter object")
self._state.kf = self._split_type
if isinstance(self._state.kf, GroupKFold):
# self._split_type is either "group" or a GroupKFold object
if isinstance(self._state.kf, (GroupKFold, StratifiedGroupKFold)):
# self._split_type is either "group", a GroupKFold object, or a StratifiedGroupKFold object
self._state.kf.groups = self._state.groups_all
def add_learner(self, learner_name, learner_class):

View File

@ -17,7 +17,12 @@ from sklearn.metrics import (
mean_absolute_percentage_error,
ndcg_score,
)
from sklearn.model_selection import RepeatedStratifiedKFold, GroupKFold, TimeSeriesSplit
from sklearn.model_selection import (
RepeatedStratifiedKFold,
GroupKFold,
TimeSeriesSplit,
StratifiedGroupKFold,
)
from flaml.automl.model import (
XGBoostSklearnEstimator,
XGBoost_TS,
@ -517,7 +522,7 @@ def evaluate_model_CV(
shuffle = getattr(kf, "shuffle", task not in TS_FORECAST)
if isinstance(kf, RepeatedStratifiedKFold):
kf = kf.split(X_train_split, y_train_split)
elif isinstance(kf, GroupKFold):
elif isinstance(kf, (GroupKFold, StratifiedGroupKFold)):
groups = kf.groups
kf = kf.split(X_train_split, y_train_split, groups)
shuffle = False
@ -548,8 +553,16 @@ def evaluate_model_CV(
weight[val_index],
)
if groups is not None:
fit_kwargs["groups"] = groups[train_index]
groups_val = groups[val_index]
fit_kwargs["groups"] = (
groups[train_index]
if isinstance(groups, np.ndarray)
else groups.iloc[train_index]
)
groups_val = (
groups[val_index]
if isinstance(groups, np.ndarray)
else groups.iloc[val_index]
)
else:
groups_val = None
val_loss_i, metric_i, train_time_i, pred_time_i = get_val_loss(

View File

@ -94,6 +94,33 @@ def test_groups():
automl.fit(X, y, **automl_settings)
def test_stratified_groupkfold():
from sklearn.model_selection import StratifiedGroupKFold
from flaml.data import load_openml_dataset
X_train, _, y_train, _ = load_openml_dataset(dataset_id=1169, data_dir="test/")
splitter = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=0)
automl = AutoML()
settings = {
"time_budget": 6,
"metric": "ap",
"eval_method": "cv",
"split_type": splitter,
"groups": X_train["Airline"],
"estimator_list": [
"lgbm",
"rf",
"xgboost",
"extra_tree",
"xgb_limitdepth",
"lrl1",
],
}
automl.fit(X_train=X_train, y_train=y_train, **settings)
def test_rank():
from sklearn.externals._arff import ArffException