diff --git a/flaml/ml.py b/flaml/ml.py index e0872db1a..ffcdfeff1 100644 --- a/flaml/ml.py +++ b/flaml/ml.py @@ -459,7 +459,7 @@ def evaluate_model_CV( "label_list" ) # pass the label list on to compute the evaluation metric groups = None - shuffle = False if task in TS_FORECAST else True + shuffle = getattr(kf, "shuffle", task not in TS_FORECAST) if isinstance(kf, RepeatedStratifiedKFold): kf = kf.split(X_train_split, y_train_split) elif isinstance(kf, GroupKFold): diff --git a/test/automl/test_split.py b/test/automl/test_split.py index b40631cb2..7eb8c7b50 100644 --- a/test/automl/test_split.py +++ b/test/automl/test_split.py @@ -174,6 +174,11 @@ def test_object(): automl._state.eval_method == "cv" ), "eval_method must be 'cv' for custom data splitter" + kf = TestKFold(5) + kf.shuffle = True + automl_settings["split_type"] = kf + automl.fit(X, y, **automl_settings) + if __name__ == "__main__": test_groups() diff --git a/website/docs/Use-Cases/Task-Oriented-AutoML.md b/website/docs/Use-Cases/Task-Oriented-AutoML.md index fabd0de89..6e427df7d 100644 --- a/website/docs/Use-Cases/Task-Oriented-AutoML.md +++ b/website/docs/Use-Cases/Task-Oriented-AutoML.md @@ -364,7 +364,7 @@ For both classification and regression, time-based split can be enforced if the When `eval_method="cv"`, `split_type` can also be set as a custom splitter. It needs to be an instance of a derived class of scikit-learn [KFold](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold) -and have ``split`` and ``get_n_splits`` methods with the same signatures. +and have ``split`` and ``get_n_splits`` methods with the same signatures. To disable shuffling, the splitter instance must contain the attribute `shuffle=False`. ### Parallel tuning