Fix decide_split_type bug. (#184)

* Fix decide_split_type bug.
This commit is contained in:
Gian Pio Domiziani 2021-09-02 17:50:22 +02:00 committed by GitHub
parent ec34427ca8
commit 63bba92fd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 2 deletions

View File

@ -911,7 +911,7 @@ class AutoML:
return training_duration
def _decide_split_type(self, split_type):
if self._state.task == 'classification':
if self._state.task in ('classification', 'binary', 'multi'):
self._state.task = get_classification_objective(
len(np.unique(self._y_train_all)))
assert split_type in [None, "stratified", "uniform", "time"]

View File

@ -2,7 +2,7 @@ import unittest
import numpy as np
import scipy.sparse
from sklearn.datasets import load_boston, load_iris, load_wine
from sklearn.datasets import load_boston, load_iris, load_wine, load_breast_cancer
import pandas as pd
from datetime import datetime
@ -282,6 +282,20 @@ class TestAutoML(unittest.TestCase):
filename=automl_settings['log_file_name'], time_budget=6)
print(metric_history)
def test_binary(self):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 1,
"task": 'binary',
"log_file_name": "test/breast_cancer.log",
"log_training_metric": True,
"n_jobs": 1,
"model_history": True
}
X_train, y_train = load_breast_cancer(return_X_y=True)
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
_ = automl_experiment.predict(X_train)
def test_classification(self, as_frame=False):
automl_experiment = AutoML()
automl_settings = {