Disable shuffle for custom CV (#659)

* Disable shuffle for custom CV

* Add custom fold shuffle test

* Update test_split.py

* Update test_split.py
This commit is contained in:
jmrichardson 2022-08-12 20:05:32 -04:00 committed by GitHub
parent ca9f9054e7
commit e43485607a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 2 deletions

View File

@ -459,7 +459,7 @@ def evaluate_model_CV(
"label_list"
) # pass the label list on to compute the evaluation metric
groups = None
shuffle = False if task in TS_FORECAST else True
shuffle = getattr(kf, "shuffle", task not in TS_FORECAST)
if isinstance(kf, RepeatedStratifiedKFold):
kf = kf.split(X_train_split, y_train_split)
elif isinstance(kf, GroupKFold):

View File

@ -174,6 +174,11 @@ def test_object():
automl._state.eval_method == "cv"
), "eval_method must be 'cv' for custom data splitter"
kf = TestKFold(5)
kf.shuffle = True
automl_settings["split_type"] = kf
automl.fit(X, y, **automl_settings)
if __name__ == "__main__":
test_groups()

View File

@ -364,7 +364,7 @@ For both classification and regression, time-based split can be enforced if the
When `eval_method="cv"`, `split_type` can also be set as a custom splitter. It needs to be an instance of a derived class of scikit-learn
[KFold](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold)
and have ``split`` and ``get_n_splits`` methods with the same signatures.
and have ``split`` and ``get_n_splits`` methods with the same signatures. To disable shuffling, the splitter instance must contain the attribute `shuffle=False`.
### Parallel tuning