add save_best_config()

This commit is contained in:
Chi Wang 2021-12-04 18:27:38 +00:00
parent 54d303a95a
commit 7d269435ae
2 changed files with 25 additions and 10 deletions

View File

@ -1,10 +1,10 @@
# !
# * Copyright (c) Microsoft Corporation. All rights reserved.
# * Copyright (c) FLAML authors. All rights reserved.
# * Licensed under the MIT License. See LICENSE file in the
# * project root for license information.
import time
import os
from typing import Callable, Optional
from typing import Callable, Optional, List, Union
from functools import partial
import numpy as np
from scipy.sparse import issparse
@ -20,10 +20,7 @@ from sklearn.utils import shuffle
from sklearn.base import BaseEstimator
import pandas as pd
import logging
from typing import List, Union
from pandas import DataFrame
from .data import _is_nlp_task
import json
from .ml import (
compute_estimator,
train_estimator,
@ -40,8 +37,14 @@ from .config import (
N_SPLITS,
SAMPLE_MULTIPLY_FACTOR,
)
from .data import concat, CLASSIFICATION, TS_FORECAST, FORECAST, REGRESSION
from .data import (
concat,
CLASSIFICATION,
TS_FORECAST,
FORECAST,
REGRESSION,
_is_nlp_task,
)
from . import tune
from .training_log import training_log_reader, training_log_writer
@ -678,6 +681,15 @@ class AutoML(BaseEstimator):
self._search_states[self._best_estimator], "best_config_train_time", None
)
def save_best_config(self, filename):
best = {
"class": self.best_estimator,
"hyperparameters": self.best_config,
}
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, "w") as f:
json.dump(best, f)
@property
def classes_(self):
"""A list of n_classes elements for class labels."""
@ -694,7 +706,9 @@ class AutoML(BaseEstimator):
"""Time taken to find best model in seconds."""
return self.__dict__.get("_time_taken_best_iter")
def predict(self, X_test: Union[np.array, DataFrame, List[str], List[List[str]]]):
def predict(
self, X_test: Union[np.array, pd.DataFrame, List[str], List[List[str]]]
):
"""Predict label from features.
Args:
@ -763,7 +777,7 @@ class AutoML(BaseEstimator):
try:
if isinstance(X[0], List):
X = [x for x in zip(*X)]
X = DataFrame(
X = pd.DataFrame(
dict(
[
(self._transformer._str_columns[idx], X[idx])

View File

@ -119,3 +119,4 @@ class TestLogging(unittest.TestCase):
pred2 = automl.predict(X_train)
delta = pred1 - pred2
assert max(delta) == 0 and min(delta) == 0
automl.save_best_config("test/housing.json")