mirror of https://github.com/microsoft/autogen.git
add save_best_config()
This commit is contained in:
parent
54d303a95a
commit
7d269435ae
|
@ -1,10 +1,10 @@
|
|||
# !
|
||||
# * Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# * Copyright (c) FLAML authors. All rights reserved.
|
||||
# * Licensed under the MIT License. See LICENSE file in the
|
||||
# * project root for license information.
|
||||
import time
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, Optional, List, Union
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
from scipy.sparse import issparse
|
||||
|
@ -20,10 +20,7 @@ from sklearn.utils import shuffle
|
|||
from sklearn.base import BaseEstimator
|
||||
import pandas as pd
|
||||
import logging
|
||||
from typing import List, Union
|
||||
from pandas import DataFrame
|
||||
from .data import _is_nlp_task
|
||||
|
||||
import json
|
||||
from .ml import (
|
||||
compute_estimator,
|
||||
train_estimator,
|
||||
|
@ -40,8 +37,14 @@ from .config import (
|
|||
N_SPLITS,
|
||||
SAMPLE_MULTIPLY_FACTOR,
|
||||
)
|
||||
|
||||
from .data import concat, CLASSIFICATION, TS_FORECAST, FORECAST, REGRESSION
|
||||
from .data import (
|
||||
concat,
|
||||
CLASSIFICATION,
|
||||
TS_FORECAST,
|
||||
FORECAST,
|
||||
REGRESSION,
|
||||
_is_nlp_task,
|
||||
)
|
||||
from . import tune
|
||||
from .training_log import training_log_reader, training_log_writer
|
||||
|
||||
|
@ -678,6 +681,15 @@ class AutoML(BaseEstimator):
|
|||
self._search_states[self._best_estimator], "best_config_train_time", None
|
||||
)
|
||||
|
||||
def save_best_config(self, filename):
|
||||
best = {
|
||||
"class": self.best_estimator,
|
||||
"hyperparameters": self.best_config,
|
||||
}
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
with open(filename, "w") as f:
|
||||
json.dump(best, f)
|
||||
|
||||
@property
|
||||
def classes_(self):
|
||||
"""A list of n_classes elements for class labels."""
|
||||
|
@ -694,7 +706,9 @@ class AutoML(BaseEstimator):
|
|||
"""Time taken to find best model in seconds."""
|
||||
return self.__dict__.get("_time_taken_best_iter")
|
||||
|
||||
def predict(self, X_test: Union[np.array, DataFrame, List[str], List[List[str]]]):
|
||||
def predict(
|
||||
self, X_test: Union[np.array, pd.DataFrame, List[str], List[List[str]]]
|
||||
):
|
||||
"""Predict label from features.
|
||||
|
||||
Args:
|
||||
|
@ -763,7 +777,7 @@ class AutoML(BaseEstimator):
|
|||
try:
|
||||
if isinstance(X[0], List):
|
||||
X = [x for x in zip(*X)]
|
||||
X = DataFrame(
|
||||
X = pd.DataFrame(
|
||||
dict(
|
||||
[
|
||||
(self._transformer._str_columns[idx], X[idx])
|
||||
|
|
|
@ -119,3 +119,4 @@ class TestLogging(unittest.TestCase):
|
|||
pred2 = automl.predict(X_train)
|
||||
delta = pred1 - pred2
|
||||
assert max(delta) == 0 and min(delta) == 0
|
||||
automl.save_best_config("test/housing.json")
|
||||
|
|
Loading…
Reference in New Issue