multiple logged metrics in cv (#114)

This commit is contained in:
Chi Wang 2021-06-18 21:19:59 -07:00 committed by GitHub
parent 3a2b6cdddc
commit e039861ab0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 4 deletions

View File

@ -233,7 +233,10 @@ def evaluate_model_CV(
valid_fold_num += 1
total_val_loss += val_loss_i
if train_loss is not False:
if total_train_loss != 0:
if isinstance(total_train_loss, list):
total_train_loss = [
total_train_loss[i] + v for i, v in enumerate(train_loss_i)]
elif total_train_loss != 0:
total_train_loss += train_loss_i
else:
total_train_loss = train_loss_i
@ -246,7 +249,10 @@ def evaluate_model_CV(
break
val_loss = np.max(val_loss_list)
if train_loss is not False:
train_loss = total_train_loss / n
if isinstance(total_train_loss, list):
train_loss = [v / n for v in total_train_loss]
else:
train_loss = total_train_loss / n
budget -= time.time() - start_time
if val_loss < best_val_loss and budget > budget_per_train:
estimator.cleanup()

View File

@ -153,8 +153,8 @@ class TestAutoML(unittest.TestCase):
X_train, y_train = load_iris(return_X_y=True)
automl_experiment = AutoML()
automl_settings = {
"time_budget": 10,
'eval_method': 'holdout',
"time_budget": 5,
'eval_method': 'cv',
"metric": custom_metric,
"task": 'classification',
"log_file_name": "test/iris_custom.log",