create an automl option to remove unnecessary dependency for autogen and tune (#1007)

* version update post release v1.2.2

* automl option

* import pandas

* remove automl.utils

* default

* test

* type hint and version update

* dependency update

* link to open in colab

* use packaging.version to close #725

---------

Co-authored-by: Li Jiang <lijiang1@microsoft.com>
Co-authored-by: Li Jiang <bnujli@gmail.com>
Chi Wang 2023-05-24 16:55:04 -07:00 committed by GitHub
parent e9fdbc6e02
commit a0b318b12e
48 changed files with 2013 additions and 1154 deletions
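For context, a minimal sketch of the behavior this commit aims for: the heavy AutoML dependencies become an optional `automl` extra, and the ImportError is deferred until `AutoML` is actually constructed. The error message is taken from the diff below; the module path and import behavior are assumptions.

```python
# Sketch only, not part of the commit. Assumes flaml was installed without the
# new extra, i.e. `pip install flaml` rather than `pip install flaml[automl]`.
from flaml.automl.automl import AutoML  # importing the module no longer fails

try:
    automl = AutoML()  # the stored ERROR is raised here instead of at import time
except ImportError as e:
    # "please install flaml[automl] option to use the flaml.automl package."
    print(e)
```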

View File

@ -9,8 +9,6 @@ import sys
from typing import Callable, List, Union, Optional
from functools import partial
import numpy as np
from sklearn.base import BaseEstimator
import pandas as pd
import logging
import json
@ -38,36 +36,18 @@ from flaml.automl.logger import logger, logger_formatter
from flaml.automl.training_log import training_log_reader, training_log_writer
from flaml.default import suggest_learner
from flaml.version import __version__ as flaml_version
from flaml.automl.spark import psDataFrame, psSeries, DataFrame, Series
from flaml.tune.spark.utils import check_spark, get_broadcast_data
ERROR = (
DataFrame is None and ImportError("please install flaml[automl] option to use the flaml.automl package.") or None
)
try:
from flaml.automl.spark.utils import (
train_test_split_pyspark,
unique_pandas_on_spark,
len_labels,
unique_value_first_index,
)
from sklearn.base import BaseEstimator
except ImportError:
train_test_split_pyspark = None
unique_pandas_on_spark = None
from flaml.automl.utils import (
len_labels,
unique_value_first_index,
)
try:
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
import pyspark.pandas as ps
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
from pyspark.pandas.config import set_option, reset_option
except ImportError:
ps = None
class psDataFrame:
pass
class psSeries:
pass
BaseEstimator = object
ERROR = ERROR or ImportError("please install flaml[automl] option to use the flaml.automl package.")
try:
import mlflow
@ -78,7 +58,6 @@ try:
from ray import __version__ as ray_version
assert ray_version >= "1.10.0"
ray_available = True
except (ImportError, AssertionError):
ray_available = False
@ -346,6 +325,8 @@ class AutoML(BaseEstimator):
FLAML will create nested runs.
"""
if ERROR:
raise ERROR
self._track_iter = 0
self._state = AutoMLState()
self._state.learner_classes = {}
@ -540,8 +521,8 @@ class AutoML(BaseEstimator):
def score(
self,
X: Union[pd.DataFrame, psDataFrame],
y: Union[pd.Series, psSeries],
X: Union[DataFrame, psDataFrame],
y: Union[Series, psSeries],
**kwargs,
):
estimator = getattr(self, "_trained_estimator", None)
@ -555,7 +536,7 @@ class AutoML(BaseEstimator):
def predict(
self,
X: Union[np.array, pd.DataFrame, List[str], List[List[str]], psDataFrame],
X: Union[np.array, DataFrame, List[str], List[List[str]], psDataFrame],
**pred_kwargs,
):
"""Predict label from features.
@ -574,7 +555,7 @@ class AutoML(BaseEstimator):
the searched learners, such as per_device_eval_batch_size.
```python
multivariate_X_test = pd.DataFrame({
multivariate_X_test = DataFrame({
'timeStamp': pd.date_range(start='1/1/2022', end='1/07/2022'),
'categorical_col': ['yes', 'yes', 'no', 'no', 'yes', 'no', 'yes'],
'continuous_col': [105, 107, 120, 118, 110, 112, 115]
@ -596,7 +577,7 @@ class AutoML(BaseEstimator):
if isinstance(y_pred, np.ndarray) and y_pred.ndim > 1 and isinstance(y_pred, np.ndarray):
y_pred = y_pred.flatten()
if self._label_transformer:
return self._label_transformer.inverse_transform(pd.Series(y_pred.astype(int)))
return self._label_transformer.inverse_transform(Series(y_pred.astype(int)))
else:
return y_pred

View File

@ -3,30 +3,16 @@
# * Licensed under the MIT License. See LICENSE file in the
# * project root for license information.
import numpy as np
from scipy.sparse import vstack, issparse
import pandas as pd
from pandas import DataFrame, Series
from flaml.automl.training_log import training_log_reader
from datetime import datetime
from typing import TYPE_CHECKING, Union
import os
from flaml.automl.training_log import training_log_reader
from flaml.automl.spark import ps, psDataFrame, psSeries, DataFrame, Series, pd
try:
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
import pyspark.pandas as ps
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
from scipy.sparse import vstack, issparse
except ImportError:
ps = None
class psDataFrame:
pass
class psSeries:
pass
pass
if TYPE_CHECKING:
from flaml.automl.task import Task
@ -55,7 +41,6 @@ def load_openml_dataset(dataset_id, data_dir=None, random_state=0, dataset_forma
y_train: A series or array of labels for training data.
y_test: A series or array of labels for test data.
"""
import os
import openml
import pickle
from sklearn.model_selection import train_test_split
@ -108,7 +93,6 @@ def load_openml_task(task_id, data_dir):
y_train: A series of labels for training data.
y_test: A series of labels for test data.
"""
import os
import openml
import pickle

View File

@ -2,24 +2,9 @@
# * Copyright (c) FLAML authors. All rights reserved.
# * Licensed under the MIT License. See LICENSE file in the
# * project root for license information.
import os
import time
import numpy as np
import pandas as pd
from typing import Union, Callable, TypeVar, Optional, Tuple
from sklearn.metrics import (
mean_squared_error,
r2_score,
roc_auc_score,
accuracy_score,
mean_absolute_error,
log_loss,
average_precision_score,
f1_score,
mean_absolute_percentage_error,
ndcg_score,
)
from flaml.automl.model import (
XGBoostSklearnEstimator,
XGBoost_TS,
@ -47,27 +32,26 @@ from flaml.automl.model import (
from flaml.automl.data import group_counts
from flaml.automl.task.task import TS_FORECAST, Task
from flaml.automl.model import BaseEstimator
from flaml.automl.spark import psDataFrame, psSeries, ERROR as SPARK_ERROR, Series
try:
from flaml.automl.spark.utils import len_labels
from sklearn.metrics import (
mean_squared_error,
r2_score,
roc_auc_score,
accuracy_score,
mean_absolute_error,
log_loss,
average_precision_score,
f1_score,
mean_absolute_percentage_error,
ndcg_score,
)
except ImportError:
from flaml.automl.utils import len_labels
try:
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
from pyspark.sql.functions import col
import pyspark.pandas as ps
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
from flaml.automl.spark.utils import to_pandas_on_spark, iloc_pandas_on_spark
pass
if SPARK_ERROR is None:
from flaml.automl.spark.metrics import spark_metric_loss_score
except ImportError:
ps = None
class psDataFrame:
pass
class psSeries:
pass
EstimatorSubclass = TypeVar("EstimatorSubclass", bound=BaseEstimator)
@ -209,7 +193,7 @@ def metric_loss_score(
y_processed_true = [[labels[tr] for tr in each_list] for each_list in y_processed_true]
elif metric in ("pearsonr", "spearmanr"):
y_processed_true = (
y_processed_true.to_list() if isinstance(y_processed_true, pd.Series) else list(y_processed_true)
y_processed_true.to_list() if isinstance(y_processed_true, Series) else list(y_processed_true)
)
score_dict = metric.compute(predictions=y_processed_predict, references=y_processed_true)
if "rouge" in metric_name:
@ -612,7 +596,7 @@ def train_estimator(
return estimator, train_time
def norm_confusion_matrix(y_true: Union[np.array, pd.Series], y_pred: Union[np.array, pd.Series]):
def norm_confusion_matrix(y_true: Union[np.array, Series], y_pred: Union[np.array, Series]):
"""normalized confusion matrix.
Args:
@ -631,8 +615,8 @@ def norm_confusion_matrix(y_true: Union[np.array, pd.Series], y_pred: Union[np.a
def multi_class_curves(
y_true: Union[np.array, pd.Series],
y_pred_proba: Union[np.array, pd.Series],
y_true: Union[np.array, Series],
y_pred_proba: Union[np.array, Series],
curve_func: Callable,
):
"""Binarize the data for multi-class tasks and produce ROC or precision-recall curves.

View File

@ -9,14 +9,8 @@ import os
from typing import Callable, List, Union
import numpy as np
import time
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.ensemble import ExtraTreesRegressor, ExtraTreesClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.dummy import DummyClassifier, DummyRegressor
from scipy.sparse import issparse
import logging
import shutil
from pandas import DataFrame, Series, to_datetime
import sys
import math
from flaml import tune
@ -37,36 +31,28 @@ from flaml.automl.task.task import (
)
try:
from flaml.automl.spark.utils import len_labels, to_pandas_on_spark
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.ensemble import ExtraTreesRegressor, ExtraTreesClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.dummy import DummyClassifier, DummyRegressor
except ImportError:
from flaml.automl.utils import len_labels
pass
to_pandas_on_spark = None
try:
from scipy.sparse import issparse
except ImportError:
pass
from flaml.automl.spark import psDataFrame, sparkDataFrame, psSeries, ERROR as SPARK_ERROR, DataFrame, Series
from flaml.automl.spark.utils import len_labels, to_pandas_on_spark
from flaml.automl.spark.configs import (
ParamList_LightGBM_Classifier,
ParamList_LightGBM_Regressor,
ParamList_LightGBM_Ranker,
)
try:
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
from pyspark.sql.dataframe import DataFrame as sparkDataFrame
from pyspark.sql import SparkSession
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
_have_spark = True
except ImportError:
_have_spark = False
class psDataFrame:
pass
class psSeries:
pass
class sparkDataFrame:
pass
if DataFrame is not None:
from pandas import to_datetime
try:
import psutil
@ -415,8 +401,8 @@ class SparkEstimator(BaseEstimator):
"""The base class for fine-tuning spark models, using pyspark.ml and SynapseML API."""
def __init__(self, task="binary", **config):
if not _have_spark:
raise ImportError("pyspark is not installed. Try `pip install flaml[spark]`.")
if SPARK_ERROR:
raise SPARK_ERROR
super().__init__(task, **config)
self.df_train = None

View File

@ -1,8 +1,7 @@
import argparse
from dataclasses import dataclass, field
from flaml.automl.task.task import NLG_TASKS
from typing import Optional, List
from flaml.automl.task.task import NLG_TASKS
try:
from transformers import TrainingArguments

View File

@ -1,7 +1,5 @@
import pandas as pd
from itertools import chain
import numpy as np
from flaml.automl.task.task import (
SUMMARIZATION,
SEQREGRESSION,
@ -10,6 +8,7 @@ from flaml.automl.task.task import (
TOKENCLASSIFICATION,
NLG_TASKS,
)
from flaml.automl.data import pd
def todf(X, Y, column_name):

View File

@ -0,0 +1,32 @@
import os
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
try:
import pyspark
import pyspark.pandas as ps
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import DataFrame as sparkDataFrame
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries, set_option
from pyspark.util import VersionUtils
except ImportError:
class psDataFrame:
pass
F = T = ps = sparkDataFrame = psSeries = psDataFrame
_spark_major_minor_version = set_option = None
ERROR = ImportError(
"""Please run pip install flaml[spark]
and check [here](https://spark.apache.org/docs/latest/api/python/getting_started/install.html)
for more details about installing Spark."""
)
else:
ERROR = None
_spark_major_minor_version = VersionUtils.majorMinorVersion(pyspark.__version__)
try:
import pandas as pd
from pandas import DataFrame, Series
except ImportError:
DataFrame = Series = pd = None
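The new module above centralizes the optional pandas/PySpark imports behind sentinels: `pd`, `DataFrame` and `Series` become `None` when pandas is missing, the pyspark names become dummy placeholders, and `ERROR` carries the install hint. A short sketch of how the rest of the commit consumes it; the class below is hypothetical, but the pattern mirrors `SparkEstimator` and `AutoML` elsewhere in this diff.

```python
# Illustrative consumer of the shim above (import path assumed).
from flaml.automl.spark import psDataFrame, psSeries, ERROR as SPARK_ERROR

class RequiresSpark:
    def __init__(self):
        if SPARK_ERROR:        # None when pyspark imported cleanly
            raise SPARK_ERROR  # otherwise the stored "pip install flaml[spark]" hint
```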

View File

@ -1,28 +1,16 @@
import logging
import os
import numpy as np
from typing import Union
try:
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
from pyspark.sql import DataFrame
import pyspark.pandas as ps
from pyspark.ml.evaluation import (
BinaryClassificationEvaluator,
RegressionEvaluator,
MulticlassClassificationEvaluator,
MultilabelClassificationEvaluator,
RankingEvaluator,
)
import pyspark.sql.functions as F
except ImportError:
msg = """use_spark=True requires installation of PySpark. Please run pip install flaml[spark]
and check [here](https://spark.apache.org/docs/latest/api/python/getting_started/install.html)
for more details about installing Spark."""
raise ImportError(msg)
from flaml.automl.spark import psSeries, F
from pyspark.ml.evaluation import (
BinaryClassificationEvaluator,
RegressionEvaluator,
MulticlassClassificationEvaluator,
MultilabelClassificationEvaluator,
RankingEvaluator,
)
def ps_group_counts(groups: Union[ps.Series, np.ndarray]) -> np.ndarray:
def ps_group_counts(groups: Union[psSeries, np.ndarray]) -> np.ndarray:
if isinstance(groups, np.ndarray):
_, i, c = np.unique(groups, return_counts=True, return_index=True)
else:
@ -48,20 +36,20 @@ def _compute_label_from_probability(df, probability_col, prediction_col):
def spark_metric_loss_score(
metric_name: str,
y_predict: ps.Series,
y_true: ps.Series,
sample_weight: ps.Series = None,
groups: ps.Series = None,
y_predict: psSeries,
y_true: psSeries,
sample_weight: psSeries = None,
groups: psSeries = None,
) -> float:
"""
Compute the loss score of a metric for spark models.
Args:
metric_name: str | the name of the metric.
y_predict: ps.Series | the predicted values.
y_true: ps.Series | the true values.
sample_weight: ps.Series | the sample weights. Default: None.
groups: ps.Series | the group of each row. Default: None.
y_predict: psSeries | the predicted values.
y_true: psSeries | the true values.
sample_weight: psSeries | the sample weights. Default: None.
groups: psSeries | the group of each row. Default: None.
Returns:
float | the loss score. A lower value indicates a better model.

View File

@ -1,37 +1,31 @@
import logging
import os
from typing import Union, List, Optional, Tuple
import pandas as pd
import numpy as np
from flaml.automl.spark import (
sparkDataFrame,
ps,
F,
T,
psDataFrame,
psSeries,
_spark_major_minor_version,
DataFrame,
Series,
set_option,
)
logger = logging.getLogger(__name__)
logger_formatter = logging.Formatter(
"[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s", "%m-%d %H:%M:%S"
)
logger.propagate = False
try:
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
from pyspark.sql import SparkSession
from pyspark.sql import DataFrame
import pyspark.pandas as ps
from pyspark.util import VersionUtils
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pyspark
_spark_major_minor_version = VersionUtils.majorMinorVersion(pyspark.__version__)
except ImportError:
msg = """use_spark=True requires installation of PySpark. Please run pip install flaml[spark]
and check [here](https://spark.apache.org/docs/latest/api/python/getting_started/install.html)
for more details about installing Spark."""
raise ImportError(msg)
def to_pandas_on_spark(
df: Union[pd.DataFrame, DataFrame, pd.Series, ps.DataFrame, ps.Series],
df: Union[DataFrame, sparkDataFrame, Series, psDataFrame, psSeries],
index_col: Optional[str] = None,
default_index_type: Optional[str] = "distributed-sequence",
) -> Union[ps.DataFrame, ps.Series]:
) -> Union[psDataFrame, psSeries]:
"""Convert pandas or pyspark dataframe/series to pandas_on_Spark dataframe/series.
Args:
@ -46,7 +40,7 @@ def to_pandas_on_spark(
import pandas as pd
from flaml.automl.spark.utils import to_pandas_on_spark
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
pdf = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
psdf = to_pandas_on_spark(pdf)
print(psdf)
@ -57,33 +51,33 @@ def to_pandas_on_spark(
psdf = to_pandas_on_spark(sdf)
print(psdf)
pds = pd.Series([1, 2, 3])
pds = Series([1, 2, 3])
pss = to_pandas_on_spark(pds)
print(pss)
```
"""
ps.set_option("compute.default_index_type", default_index_type)
if isinstance(df, (pd.DataFrame, pd.Series)):
set_option("compute.default_index_type", default_index_type)
if isinstance(df, (DataFrame, Series)):
return ps.from_pandas(df)
elif isinstance(df, DataFrame):
elif isinstance(df, sparkDataFrame):
if _spark_major_minor_version[0] == 3 and _spark_major_minor_version[1] < 3:
return df.to_pandas_on_spark(index_col=index_col)
else:
return df.pandas_api(index_col=index_col)
elif isinstance(df, (ps.DataFrame, ps.Series)):
elif isinstance(df, (psDataFrame, psSeries)):
return df
else:
raise TypeError(f"{type(df)} is not one of pandas.DataFrame, pandas.Series and pyspark.sql.DataFrame")
def train_test_split_pyspark(
df: Union[DataFrame, ps.DataFrame],
df: Union[sparkDataFrame, psDataFrame],
stratify_column: Optional[str] = None,
test_fraction: Optional[float] = 0.2,
seed: Optional[int] = 1234,
to_pandas_spark: Optional[bool] = True,
index_col: Optional[str] = "tmp_index_col",
) -> Tuple[Union[DataFrame, ps.DataFrame], Union[DataFrame, ps.DataFrame]]:
) -> Tuple[Union[sparkDataFrame, psDataFrame], Union[sparkDataFrame, psDataFrame]]:
"""Split a pyspark dataframe into train and test dataframes.
Args:
@ -98,7 +92,7 @@ def train_test_split_pyspark(
pyspark.sql.DataFrame/pandas_on_spark DataFrame | The train dataframe.
pyspark.sql.DataFrame/pandas_on_spark DataFrame | The test dataframe.
"""
if isinstance(df, ps.DataFrame):
if isinstance(df, psDataFrame):
df = df.to_spark(index_col=index_col)
if stratify_column:
@ -123,9 +117,9 @@ def train_test_split_pyspark(
return [df_train, df_test]
def unique_pandas_on_spark(psds: Union[ps.Series, ps.DataFrame]) -> Tuple[np.ndarray, np.ndarray]:
def unique_pandas_on_spark(psds: Union[psSeries, psDataFrame]) -> Tuple[np.ndarray, np.ndarray]:
"""Get the unique values and counts of a pandas_on_spark series."""
if isinstance(psds, ps.DataFrame):
if isinstance(psds, psDataFrame):
psds = psds.iloc[:, 0]
_tmp = psds.value_counts().to_pandas()
label_set = _tmp.index.values
@ -133,21 +127,21 @@ def unique_pandas_on_spark(psds: Union[ps.Series, ps.DataFrame]) -> Tuple[np.nda
return label_set, counts
def len_labels(y: Union[ps.Series, np.ndarray], return_labels=False) -> Union[int, Optional[np.ndarray]]:
def len_labels(y: Union[psSeries, np.ndarray], return_labels=False) -> Union[int, Optional[np.ndarray]]:
"""Get the number of unique labels in y."""
if not isinstance(y, (ps.DataFrame, ps.Series)):
if not isinstance(y, (psDataFrame, psSeries)):
labels = np.unique(y)
else:
labels = y.unique() if isinstance(y, ps.Series) else y.iloc[:, 0].unique()
labels = y.unique() if isinstance(y, psSeries) else y.iloc[:, 0].unique()
if return_labels:
return len(labels), labels
return len(labels)
def unique_value_first_index(y: Union[pd.Series, ps.Series, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
def unique_value_first_index(y: Union[Series, psSeries, np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
"""Get the unique values and indices of a pandas series,
pandas_on_spark series or numpy array."""
if isinstance(y, ps.Series):
if isinstance(y, psSeries):
y_unique = y.drop_duplicates().sort_index()
label_set = y_unique.values
first_index = y_unique.index.values
@ -157,20 +151,20 @@ def unique_value_first_index(y: Union[pd.Series, ps.Series, np.ndarray]) -> Tupl
def iloc_pandas_on_spark(
psdf: Union[ps.DataFrame, ps.Series, pd.DataFrame, pd.Series],
psdf: Union[psDataFrame, psSeries, DataFrame, Series],
index: Union[int, slice, list],
index_col: Optional[str] = "tmp_index_col",
) -> Union[ps.DataFrame, ps.Series]:
) -> Union[psDataFrame, psSeries]:
"""Get the rows of a pandas_on_spark dataframe/series by index."""
if isinstance(psdf, (pd.DataFrame, pd.Series)):
if isinstance(psdf, (DataFrame, Series)):
return psdf.iloc[index]
if isinstance(index, (int, slice)):
if isinstance(psdf, ps.Series):
if isinstance(psdf, psSeries):
return psdf.iloc[index]
else:
return psdf.iloc[index, :]
elif isinstance(index, list):
if isinstance(psdf, ps.Series):
if isinstance(psdf, psSeries):
sdf = psdf.to_frame().to_spark(index_col=index_col)
else:
if index_col not in psdf.columns:
@ -179,7 +173,7 @@ def iloc_pandas_on_spark(
sdf = psdf.to_spark()
sdfiloc = sdf.filter(F.col(index_col).isin(index))
psdfiloc = to_pandas_on_spark(sdfiloc)
if isinstance(psdf, ps.Series):
if isinstance(psdf, psSeries):
psdfiloc = psdfiloc[psdfiloc.columns.drop(index_col)[0]]
elif index_col not in psdf.columns:
psdfiloc = psdfiloc.drop(columns=[index_col])
@ -189,17 +183,17 @@ def iloc_pandas_on_spark(
def spark_kFold(
dataset: Union[DataFrame, ps.DataFrame],
dataset: Union[sparkDataFrame, psDataFrame],
nFolds: int = 3,
foldCol: str = "",
seed: int = 42,
index_col: Optional[str] = "tmp_index_col",
) -> List[Tuple[ps.DataFrame, ps.DataFrame]]:
) -> List[Tuple[psDataFrame, psDataFrame]]:
"""Generate k-fold splits for a Spark DataFrame.
Adopted from https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/tuning.html#CrossValidator
Args:
dataset: DataFrame / ps.DataFrame. | The DataFrame to split.
dataset: sparkDataFrame / psDataFrame. | The DataFrame to split.
nFolds: int | The number of folds. Default is 3.
foldCol: str | The column name to use for fold numbers. If not specified,
the DataFrame will be randomly split. Default is "".
@ -213,7 +207,7 @@ def spark_kFold(
Returns:
A list of (train, validation) DataFrames.
"""
if isinstance(dataset, ps.DataFrame):
if isinstance(dataset, psDataFrame):
dataset = dataset.to_spark(index_col=index_col)
datasets = []
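A hedged usage sketch for `spark_kFold` as declared above: it takes a pyspark or pandas-on-Spark DataFrame and returns a list of (train, validation) pairs. The session, toy data, and module path are assumptions for illustration.

```python
# Hypothetical usage sketch; requires pyspark (pip install flaml[spark]).
from pyspark.sql import SparkSession
from flaml.automl.spark.utils import spark_kFold

spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame([(i, float(i % 3)) for i in range(12)], ["x", "label"])

for train_df, val_df in spark_kFold(sdf, nFolds=3, seed=42):
    ...  # fit and evaluate on each (train, validation) pair
```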

View File

@ -1,43 +1,12 @@
import inspect
import time
import os
from typing import Any, Optional
import numpy as np
import pandas as pd
from flaml import tune
from flaml.automl.logger import logger
from flaml.automl.ml import compute_estimator, train_estimator
from flaml.automl.task.task import TS_FORECAST
try:
from flaml.automl.spark.utils import (
train_test_split_pyspark,
unique_pandas_on_spark,
len_labels,
unique_value_first_index,
)
except ImportError:
train_test_split_pyspark = None
unique_pandas_on_spark = None
from flaml.automl.utils import (
len_labels,
unique_value_first_index,
)
try:
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
import pyspark.pandas as ps
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
from pyspark.pandas.config import set_option, reset_option
except ImportError:
ps = None
class psDataFrame:
pass
class psSeries:
pass
from flaml.automl.spark import psDataFrame, psSeries, DataFrame, Series
class SearchState:
@ -245,11 +214,11 @@ class AutoMLState:
def _prepare_sample_train_data(self, sample_size: int):
sampled_weight = groups = None
if sample_size <= self.data_size[0]:
if isinstance(self.X_train, (pd.DataFrame, psDataFrame)):
if isinstance(self.X_train, (DataFrame, psDataFrame)):
sampled_X_train = self.X_train.iloc[:sample_size]
else:
sampled_X_train = self.X_train[:sample_size]
if isinstance(self.y_train, (pd.Series, psSeries)):
if isinstance(self.y_train, (Series, psSeries)):
sampled_y_train = self.y_train.iloc[:sample_size]
else:
sampled_y_train = self.y_train[:sample_size]
@ -258,12 +227,12 @@ class AutoMLState:
) # NOTE: _prepare_sample_train_data is before kwargs is updated to fit_kwargs_by_estimator
if weight is not None:
sampled_weight = (
weight.iloc[:sample_size] if isinstance(weight, (pd.Series, psSeries)) else weight[:sample_size]
weight.iloc[:sample_size] if isinstance(weight, (Series, psSeries)) else weight[:sample_size]
)
if self.groups is not None:
groups = (
self.groups.iloc[:sample_size]
if isinstance(self.groups, (pd.Series, psSeries))
if isinstance(self.groups, (Series, psSeries))
else self.groups[:sample_size]
)
else:

View File

@ -1,15 +1,13 @@
from typing import Optional, Union
import numpy as np
import pandas as pd
from flaml.automl.task.generic_task import GenericTask
from flaml.automl.task.task import Task
from flaml.automl.data import DataFrame, Series
def task_factory(
task_name: str,
X_train: Optional[Union[np.ndarray, pd.DataFrame]] = None,
y_train: Optional[Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
X_train: Optional[Union[np.ndarray, DataFrame]] = None,
y_train: Optional[Union[np.ndarray, DataFrame, Series]] = None,
) -> Task:
return GenericTask(task_name, X_train, y_train)

View File

@ -1,22 +1,7 @@
import os
import logging
import time
from typing import List, Optional
import pandas as pd
import numpy as np
from scipy.sparse import issparse
from sklearn.utils import shuffle
from sklearn.model_selection import (
train_test_split,
RepeatedStratifiedKFold,
RepeatedKFold,
GroupKFold,
TimeSeriesSplit,
GroupShuffleSplit,
StratifiedGroupKFold,
)
from flaml.automl.data import TS_TIMESTAMP_COL, concat
from flaml.automl.ml import EstimatorSubclass, default_cv_score_agg_func, get_val_loss
from flaml.automl.model import (
@ -40,40 +25,34 @@ from flaml.automl.task.task import (
TS_FORECASTPANEL,
)
from flaml.config import RANDOM_SEED
from flaml.automl.spark import ps, psDataFrame, psSeries, pd
from flaml.automl.spark.utils import (
iloc_pandas_on_spark,
spark_kFold,
train_test_split_pyspark,
unique_pandas_on_spark,
unique_value_first_index,
len_labels,
set_option,
)
try:
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
from pyspark.sql.functions import col
import pyspark.pandas as ps
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
from pyspark.pandas.config import set_option, reset_option
from flaml.automl.spark.utils import (
to_pandas_on_spark,
iloc_pandas_on_spark,
spark_kFold,
train_test_split_pyspark,
unique_pandas_on_spark,
unique_value_first_index,
len_labels,
)
from flaml.automl.spark.metrics import spark_metric_loss_score
from scipy.sparse import issparse
except ImportError:
train_test_split_pyspark = None
unique_pandas_on_spark = None
iloc_pandas_on_spark = None
from flaml.automl.utils import (
len_labels,
unique_value_first_index,
pass
try:
from sklearn.utils import shuffle
from sklearn.model_selection import (
train_test_split,
RepeatedStratifiedKFold,
RepeatedKFold,
GroupKFold,
TimeSeriesSplit,
GroupShuffleSplit,
StratifiedGroupKFold,
)
ps = None
class psDataFrame:
pass
class psSeries:
pass
except ImportError:
pass
logger = logging.getLogger(__name__)

View File

@ -1,17 +1,11 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from flaml.automl.data import DataFrame, Series, psDataFrame, psSeries
if TYPE_CHECKING:
import flaml
try:
import ray
except ImportError:
ray = None
# TODO: if your task is not specified in here, define your task as an all-capitalized word
SEQCLASSIFICATION = "seq-classification"
MULTICHOICECLASSIFICATION = "multichoice-classification"
@ -80,8 +74,8 @@ class Task(ABC):
def __init__(
self,
task_name: str,
X_train: Optional[Union[np.ndarray, pd.DataFrame]] = None,
y_train: Optional[Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
X_train: Optional[Union[np.ndarray, DataFrame, psDataFrame]] = None,
y_train: Optional[Union[np.ndarray, DataFrame, Series, psSeries]] = None,
):
"""Constructor.
@ -104,8 +98,8 @@ class Task(ABC):
self,
config: dict,
estimator: "flaml.automl.ml.BaseEstimator",
X_train_all: Union[np.ndarray, pd.DataFrame],
y_train_all: Union[np.ndarray, pd.DataFrame, pd.Series],
X_train_all: Union[np.ndarray, DataFrame, psDataFrame],
y_train_all: Union[np.ndarray, DataFrame, Series, psSeries],
budget: int,
kf,
eval_metric: str,
@ -136,12 +130,12 @@ class Task(ABC):
self,
automl: "flaml.automl.automl.AutoML",
state: "flaml.automl.state.AutoMLState",
X_train_all: Union[np.ndarray, pd.DataFrame, None],
y_train_all: Union[np.ndarray, pd.DataFrame, pd.Series, None],
dataframe: Union[pd.DataFrame, None],
X_train_all: Union[np.ndarray, DataFrame, psDataFrame, None],
y_train_all: Union[np.ndarray, DataFrame, Series, psSeries, None],
dataframe: Union[DataFrame, None],
label: str,
X_val: Optional[Union[np.ndarray, pd.DataFrame]] = None,
y_val: Optional[Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
X_val: Optional[Union[np.ndarray, DataFrame, psDataFrame]] = None,
y_val: Optional[Union[np.ndarray, DataFrame, Series, psSeries]] = None,
groups_val: Optional[List[str]] = None,
groups: Optional[List[str]] = None,
):
@ -169,8 +163,8 @@ class Task(ABC):
def prepare_data(
self,
state: "flaml.automl.state.AutoMLState",
X_train_all: Union[np.ndarray, pd.DataFrame],
y_train_all: Union[np.ndarray, pd.DataFrame, pd.Series, None],
X_train_all: Union[np.ndarray, DataFrame, psDataFrame],
y_train_all: Union[np.ndarray, DataFrame, Series, psSeries, None],
auto_augment: bool,
eval_method: str,
split_type: str,
@ -203,7 +197,7 @@ class Task(ABC):
For ranking task, must be "auto" or 'group'.
split_ratio: A float of the validation data percentage for holdout.
n_splits: An integer of the number of folds for cross-validation.
data_is_df: True if the data was provided as a pd.DataFrame else False.
data_is_df: True if the data was provided as a DataFrame else False.
sample_weight_full: A 1d arraylike of the sample weight.
Raises:
@ -214,7 +208,7 @@ class Task(ABC):
def decide_split_type(
self,
split_type: str,
y_train_all: Union[np.ndarray, pd.DataFrame, pd.Series, None],
y_train_all: Union[np.ndarray, DataFrame, Series, psSeries, None],
fit_kwargs: dict,
groups: Optional[List[str]] = None,
) -> str:
@ -240,9 +234,9 @@ class Task(ABC):
@abstractmethod
def preprocess(
self,
X: Union[np.ndarray, pd.DataFrame],
X: Union[np.ndarray, DataFrame, psDataFrame],
transformer: Optional["flaml.automl.data.DataTransformer"] = None,
) -> Union[np.ndarray, pd.DataFrame]:
) -> Union[np.ndarray, DataFrame]:
"""Preprocess the data ready for fitting or inference with this task type.
Args:

View File

@ -1,18 +0,0 @@
from typing import Optional, Union, Tuple
import numpy as np
def len_labels(y: np.ndarray, return_labels=False) -> Union[int, Optional[np.ndarray]]:
"""Get the number of unique labels in y. The non-spark version of
flaml.automl.spark.utils.len_labels"""
labels = np.unique(y)
if return_labels:
return len(labels), labels
return len(labels)
def unique_value_first_index(y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Get the unique values and indices of a pandas series or numpy array.
The non-spark version of flaml.automl.spark.utils.unique_value_first_index"""
label_set, first_index = np.unique(y, return_index=True)
return label_set, first_index

View File

@ -1,5 +1,5 @@
"""!
* Copyright (c) 2020-2021 Microsoft Corporation. All rights reserved.
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
"""

View File

@ -1,4 +1,3 @@
import sklearn.ensemble as ensemble
from functools import wraps
from flaml.automl.task.task import CLASSIFICATION
from .suggest import preprocess_and_suggest_hyperparams
@ -143,22 +142,31 @@ def flamlize_estimator(super_class, name: str, task: str, alternatives=None):
return EstimatorClass
RandomForestRegressor = flamlize_estimator(ensemble.RandomForestRegressor, "rf", "regression")
RandomForestClassifier = flamlize_estimator(ensemble.RandomForestClassifier, "rf", "classification")
ExtraTreesRegressor = flamlize_estimator(ensemble.ExtraTreesRegressor, "extra_tree", "regression")
ExtraTreesClassifier = flamlize_estimator(ensemble.ExtraTreesClassifier, "extra_tree", "classification")
try:
import sklearn.ensemble as ensemble
except ImportError:
RandomForestClassifier = RandomForestRegressor = ExtraTreesClassifier = ExtraTreesRegressor = ImportError(
"Using flaml.default.* requires scikit-learn."
)
else:
RandomForestRegressor = flamlize_estimator(ensemble.RandomForestRegressor, "rf", "regression")
RandomForestClassifier = flamlize_estimator(ensemble.RandomForestClassifier, "rf", "classification")
ExtraTreesRegressor = flamlize_estimator(ensemble.ExtraTreesRegressor, "extra_tree", "regression")
ExtraTreesClassifier = flamlize_estimator(ensemble.ExtraTreesClassifier, "extra_tree", "classification")
try:
import lightgbm
except ImportError:
LGBMRegressor = LGBMClassifier = ImportError("Using flaml.default.LGBM* requires lightgbm.")
else:
LGBMRegressor = flamlize_estimator(lightgbm.LGBMRegressor, "lgbm", "regression")
LGBMClassifier = flamlize_estimator(lightgbm.LGBMClassifier, "lgbm", "classification")
except ImportError:
pass
try:
import xgboost
except ImportError:
XGBClassifier = XGBRegressor = ImportError("Using flaml.default.XGB* requires xgboost.")
else:
XGBRegressor = flamlize_estimator(
xgboost.XGBRegressor,
"xgb_limitdepth",
@ -171,5 +179,3 @@ try:
"classification",
[("max_depth", 0, "xgboost")],
)
except ImportError:
pass
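The `try/except/else` refactor above makes `flaml.default` importable without its optional backends: when scikit-learn, lightgbm, or xgboost is missing, the corresponding names are bound to an `ImportError` sentinel instead of a class. A caller-side guard might look like the following; the check itself is hypothetical, not code from this diff.

```python
# Hypothetical caller-side guard for the sentinel pattern introduced above
# (import path assumed).
from flaml.default import LGBMClassifier  # succeeds even if lightgbm is absent

if isinstance(LGBMClassifier, ImportError):
    raise LGBMClassifier      # "Using flaml.default.LGBM* requires lightgbm."
clf = LGBMClassifier()        # otherwise the flamlized estimator is usable as normal
```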

View File

@ -1,41 +1,23 @@
import os
import numpy as np
from sklearn.neighbors import NearestNeighbors
import logging
import pathlib
import json
from flaml.automl.data import DataTransformer
from flaml.automl.task.task import CLASSIFICATION, get_classification_objective
from flaml.automl.task.generic_task import len_labels
from flaml.automl.ml import get_estimator_class
from flaml.version import __version__
try:
from flaml.automl.spark.utils import len_labels
from sklearn.neighbors import NearestNeighbors
except ImportError:
from flaml.automl.utils import len_labels
try:
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
import pyspark.pandas as ps
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
except ImportError:
ps = None
class psDataFrame:
pass
class psSeries:
pass
pass
LOCATION = pathlib.Path(__file__).parent.resolve()
logger = logging.getLogger(__name__)
CONFIG_PREDICTORS = {}
def version_parse(version):
return tuple(map(int, (version.split("."))))
def meta_feature(task, X_train, y_train, meta_feature_names):
this_feature = []
n_row = X_train.shape[0]
@ -94,6 +76,8 @@ def suggest_config(
The returned configs can be used as starting points for AutoML.fit().
`FLAML_sample_size` is removed from the configs.
"""
from packaging.version import parse as version_parse
task = get_classification_objective(len_labels(y)) if task == "classification" and y is not None else task
predictor = (
load_config_predictor(estimator_or_predictor, task, location)
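The local `version_parse` helper, `tuple(map(int, version.split(".")))`, is dropped in favor of `packaging.version.parse` (the "use packaging.version to close #725" item in the commit message). The reason is visible in this same commit: the package version becomes the pre-release string `2.0.0rc1`, which the old parser cannot handle. A small sketch of the difference:

```python
# Sketch: why packaging.version replaces the tuple-of-ints parser removed above.
from packaging.version import parse as version_parse

try:
    tuple(map(int, "2.0.0rc1".split(".")))  # old approach
except ValueError as e:
    print(e)  # int() cannot parse "0rc1"

assert version_parse("2.0.0rc1") < version_parse("2.0.0")  # PEP 440 pre-release ordering
```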

View File

@ -5,9 +5,13 @@ import math
import copy
import collections
from typing import Optional, Union
from sklearn.metrics import mean_squared_error, mean_absolute_error
from flaml.tune import Trial
try:
from sklearn.metrics import mean_squared_error, mean_absolute_error
except ImportError:
pass
logger = logging.getLogger(__name__)

View File

@ -958,9 +958,7 @@ try:
from nni.tuner import Tuner as NNITuner
from nni.utils import extract_scalar_reward
except ImportError:
class NNITuner:
pass
NNITuner = object
def extract_scalar_reward(x: Dict):
return x.get("default")

View File

@ -11,20 +11,19 @@ logger_formatter = logging.Formatter(
"[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s", "%m-%d %H:%M:%S"
)
logger.propagate = False
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
try:
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
import pyspark
from pyspark.sql import SparkSession
from pyspark.util import VersionUtils
import py4j
_have_spark = True
_spark_major_minor_version = VersionUtils.majorMinorVersion(pyspark.__version__)
except ImportError as e:
logger.debug("Could not import pyspark: %s", e)
except ImportError:
_have_spark = False
py4j = None
_spark_major_minor_version = (0, 0)
else:
_have_spark = True
_spark_major_minor_version = VersionUtils.majorMinorVersion(pyspark.__version__)
@lru_cache(maxsize=2)
@ -37,7 +36,7 @@ def check_spark():
Return (True, None) if the check passes, otherwise log the exception message and
return (False, Exception(msg)). The exception can be raised by the caller.
"""
logger.debug("\ncheck Spark installation...This line should appear only once.\n")
logger.debug("\nchecking Spark installation...This line should appear only once.\n")
if not _have_spark:
msg = """use_spark=True requires installation of PySpark. Please run pip install flaml[spark]
and check [here](https://spark.apache.org/docs/latest/api/python/getting_started/install.html)
@ -51,7 +50,6 @@ def check_spark():
try:
SparkSession.builder.getOrCreate()
except RuntimeError as e:
# logger.warning(f"\nSparkSession is not available: {e}\n")
return False, RuntimeError(e)
return True, None

View File

@ -15,16 +15,16 @@ try:
assert ray_version >= "1.10.0"
from ray.tune.analysis import ExperimentAnalysis as EA
ray_available = True
except (ImportError, AssertionError):
ray_available = False
from .analysis import ExperimentAnalysis as EA
else:
ray_available = True
from .trial import Trial
from .result import DEFAULT_METRIC
import logging
from flaml.tune.spark.utils import PySparkOvertimeMonitor
from flaml.tune.spark.utils import PySparkOvertimeMonitor, check_spark
logger = logging.getLogger(__name__)
logger.propagate = False
@ -231,7 +231,7 @@ def run(
n_concurrent_trials: Optional[int] = 0,
**ray_args,
):
"""The trigger for HPO.
"""The function-based way of performing HPO.
Example:
@ -612,8 +612,6 @@ def run(
if use_spark:
# parallel run with spark
from flaml.tune.spark.utils import check_spark
spark_available, spark_error_msg = check_spark()
if not spark_available:
raise spark_error_msg
@ -811,3 +809,84 @@ def run(
_runner = old_runner
logger.handlers = old_handlers
logger.setLevel(old_level)
class Tuner:
"""Tuner is the class-based way of launching hyperparameter tuning jobs compatible with Ray Tune 2.
Args:
trainable: A user-defined evaluation function.
It takes a configuration as input and outputs an evaluation
result (can be a numerical value or a dictionary of string
and numerical value pairs) for the input configuration.
For machine learning tasks, it usually involves training and
scoring a machine learning model, e.g., through validation loss.
param_space: Search space of the tuning job.
One thing to note is that both preprocessor and dataset can be tuned here.
tune_config: Tuning algorithm specific configs.
Refer to ray.tune.tune_config.TuneConfig for more info.
run_config: Runtime configuration that is specific to individual trials.
If passed, this will overwrite the run config passed to the Trainer,
if applicable. Refer to ray.air.config.RunConfig for more info.
Usage pattern:
.. code-block:: python
from sklearn.datasets import load_breast_cancer
from ray import tune
from ray.data import from_pandas
from ray.air.config import RunConfig, ScalingConfig
from ray.train.xgboost import XGBoostTrainer
from ray.tune.tuner import Tuner
def get_dataset():
data_raw = load_breast_cancer(as_frame=True)
dataset_df = data_raw["data"]
dataset_df["target"] = data_raw["target"]
dataset = from_pandas(dataset_df)
return dataset
trainer = XGBoostTrainer(
label_column="target",
params={},
datasets={"train": get_dataset()},
)
param_space = {
"scaling_config": ScalingConfig(
num_workers=tune.grid_search([2, 4]),
resources_per_worker={
"CPU": tune.grid_search([1, 2]),
},
),
# You can even grid search various datasets in Tune.
# "datasets": {
# "train": tune.grid_search(
# [ds1, ds2]
# ),
# },
"params": {
"objective": "binary:logistic",
"tree_method": "approx",
"eval_metric": ["logloss", "error"],
"eta": tune.loguniform(1e-4, 1e-1),
"subsample": tune.uniform(0.5, 1.0),
"max_depth": tune.randint(1, 9),
},
}
tuner = Tuner(trainable=trainer, param_space=param_space,
run_config=RunConfig(name="my_tune_run"))
analysis = tuner.fit()
To retry a failed tune run, you can then do
.. code-block:: python
tuner = Tuner.restore(experiment_checkpoint_dir)
tuner.fit()
``experiment_checkpoint_dir`` can be easily located near the end of the
console output of your first failed run.
"""

View File

@ -1 +1 @@
__version__ = "1.2.4"
__version__ = "2.0.0rc1"

View File

@ -1,5 +1,13 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_chatgpt_gpt4.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"attachments": {},
"cell_type": "markdown",
@ -23,7 +31,7 @@
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [openai,blendsearch] option:\n",
"```bash\n",
"pip install flaml[openai,blendsearch]==1.2.2\n",
"pip install flaml[openai,blendsearch]\n",
"```"
]
},
@ -40,7 +48,7 @@
},
"outputs": [],
"source": [
"# %pip install flaml[openai,blendsearch]==1.2.2 datasets"
"# %pip install flaml[openai,blendsearch] datasets"
]
},
{

View File

@ -1,5 +1,13 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/autogen_openai.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"attachments": {},
"cell_type": "markdown",
@ -23,7 +31,7 @@
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [autogen,blendsearch] option:\n",
"```bash\n",
"pip install flaml[autogen,blendsearch]==1.2.2\n",
"pip install flaml[autogen,blendsearch]\n",
"```"
]
},
@ -40,7 +48,7 @@
},
"outputs": [],
"source": [
"# %pip install flaml[autogen,blendsearch]==1.2.2 datasets"
"# %pip install flaml[autogen,blendsearch] datasets"
]
},
{

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"slideshow": {
@ -27,9 +28,9 @@
"\n",
"In this notebook, we demonstrate how to use FLAML library to tune hyperparameters of LightGBM with a regression example.\n",
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the `notebook` option:\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the `automl` option (this option is introduced from version 2, for version 1 it is installed by default):\n",
"```bash\n",
"pip install flaml[notebook]\n",
"pip install flaml[automl]\n",
"```"
]
},
@ -39,7 +40,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install flaml[notebook]==1.0.10"
"%pip install flaml[automl] matplotlib openml"
]
},
{
@ -786,11 +787,6 @@
"model = lgb.train(params, dtrain, valid_sets=[dtrain, dval], verbose_eval=10000) \n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 20,

File diff suppressed because one or more lines are too long

View File

@ -25,7 +25,7 @@
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the `synapse` option:\n",
"```bash\n",
"pip install flaml[synapse]>=1.1.3; \n",
"pip install flaml[synapse] \n",
"```\n",
" "
]
@ -36,7 +36,7 @@
"metadata": {},
"outputs": [],
"source": [
"# %pip install \"flaml[synapse]>=1.1.3\""
"# %pip install \"flaml[synapse]\""
]
},
{

View File

@ -8,6 +8,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -21,7 +22,7 @@
"\n",
"In this notebook, we demonstrate how to use FLAML library for time series forecasting tasks: univariate time series forecasting (only time), multivariate time series forecasting (with exogneous variables) and forecasting discrete values.\n",
"\n",
"FLAML requires Python>=3.7. To run this notebook example, please install flaml with the notebook and forecast option:\n"
"FLAML requires Python>=3.7. To run this notebook example, please install flaml with the [automl,ts_forecast] option:\n"
]
},
{
@ -156,7 +157,7 @@
}
],
"source": [
"%pip install flaml[notebook,ts_forecast]==1.1.2\n",
"%pip install flaml[automl,ts_forecast] matplotlib openml\n",
"# avoid version 1.0.2 to 1.0.5 for this notebook due to a bug for arima and sarimax's init config"
]
},

View File

@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"slideshow": {
@ -27,9 +28,9 @@
"\n",
"In this notebook, we demonstrate how to use FLAML library to tune hyperparameters of XGBoost with a regression example.\n",
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the `notebook` option:\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the `automl` option (this option is introduced from version 2, for version 1 it is installed by default):\n",
"```bash\n",
"pip install flaml[notebook]==1.1.2\n",
"pip install flaml[automl]\n",
"```"
]
},
@ -39,7 +40,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install flaml[notebook]==1.1.2"
"%pip install flaml[automl] matplotlib openml"
]
},
{

View File

@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"slideshow": {
@ -27,9 +28,9 @@
"\n",
"In this notebook, we use one real data example (binary classification) to showcase how to use FLAML library together with AzureML.\n",
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [azureml] option:\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the [automl,azureml] option:\n",
"```bash\n",
"pip install flaml[azureml]\n",
"pip install flaml[automl,azureml]\n",
"```"
]
},
@ -39,7 +40,7 @@
"metadata": {},
"outputs": [],
"source": [
"%pip install flaml[azureml]"
"%pip install flaml[automl,azureml]"
]
},
{

View File

@ -21,6 +21,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -39,12 +40,21 @@
"\n",
"In this notebook, we use one real data example (binary classification) to showcase how to use FLAML library.\n",
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the `notebook` option:\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the `[automl]` option (this option is introduced from version 2, for version 1 it is installed by default):\n",
"```bash\n",
"pip install flaml[notebook]\n",
"pip install flaml[automl]\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"%pip install flaml[automl] openml"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -72,15 +82,6 @@
"#### As FLAML's AutoML module can be used a transformer in the Sklearn's pipeline we can get all the benefits of pipeline and thereby write extremley clean, and resuable code."
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"%pip install flaml[notebook]"
]
},
{
"cell_type": "markdown",
"metadata": {},

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@ -22,7 +23,7 @@
"\n",
"*Running this notebook takes about one hour.\n",
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the `notebook` and `nlp` options:\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the legacy `[nlp]` options:\n",
"\n",
"```bash\n",
"pip install flaml[nlp]==0.7.1 # in higher version of flaml, the API for nlp tasks changed\n",
@ -362,10 +363,10 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[2m\u001B[36m(pid=50964)\u001B[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
"\u001B[2m\u001B[36m(pid=50964)\u001B[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
"\u001B[2m\u001B[36m(pid=50948)\u001B[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n",
"\u001B[2m\u001B[36m(pid=50948)\u001B[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n"
"\u001b[2m\u001b[36m(pid=50964)\u001b[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
"\u001b[2m\u001b[36m(pid=50964)\u001b[0m {'eval_loss': 0.5942569971084595, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10434782608695652}\n",
"\u001b[2m\u001b[36m(pid=50948)\u001b[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n",
"\u001b[2m\u001b[36m(pid=50948)\u001b[0m {'eval_loss': 0.649192214012146, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.2}\n"
]
},
{
@ -483,12 +484,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54411)\u001B[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=54417)\u001B[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n"
"\u001b[2m\u001b[36m(pid=54411)\u001b[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=54411)\u001b[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=54411)\u001b[0m {'eval_loss': 0.624100387096405, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=54417)\u001b[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=54417)\u001b[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=54417)\u001b[0m {'eval_loss': 0.5938675999641418, 'eval_accuracy': 0.7156862745098039, 'eval_f1': 0.8258258258258258, 'epoch': 0.5}\n"
]
},
{
@ -588,18 +589,18 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57835)\u001B[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57836)\u001B[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=57839)\u001B[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n"
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001b[2m\u001b[36m(pid=57835)\u001b[0m {'eval_loss': 0.5822290778160095, 'eval_accuracy': 0.7058823529411765, 'eval_f1': 0.8181818181818181, 'epoch': 0.5043478260869565}\n",
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001b[2m\u001b[36m(pid=57836)\u001b[0m {'eval_loss': 0.6087244749069214, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.10344827586206896}\n",
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=57839)\u001b[0m {'eval_loss': 0.5486209392547607, 'eval_accuracy': 0.7034313725490197, 'eval_f1': 0.8141321044546851, 'epoch': 0.5}\n"
]
},
{
@ -699,21 +700,21 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61251)\u001B[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61255)\u001B[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001B[2m\u001B[36m(pid=61236)\u001B[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n"
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=61251)\u001b[0m {'eval_loss': 0.6236899495124817, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.5}\n",
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001b[2m\u001b[36m(pid=61255)\u001b[0m {'eval_loss': 0.6249027848243713, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.3}\n",
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n",
"\u001b[2m\u001b[36m(pid=61236)\u001b[0m {'eval_loss': 0.6138392686843872, 'eval_accuracy': 0.6838235294117647, 'eval_f1': 0.8122270742358079, 'epoch': 0.20689655172413793}\n"
]
},
{

View File

@ -1,6 +1,15 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/zeroshot_lightgbm.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"slideshow": {
@ -19,16 +28,16 @@
"\n",
"In this notebook, we demonstrate a basic use case of zero-shot AutoML with FLAML.\n",
"\n",
"FLAML requires `Python>=3.7`. To run this notebook example, please install flaml and openml:"
"FLAML requires `Python>=3.7`. To run this notebook example, please install the [autozero] option:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# %pip install -U flaml openml;"
"# %pip install flaml[autozero] lightgbm openml;"
]
},
{
@ -51,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 2,
"metadata": {},
"outputs": [
{
@ -80,7 +89,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -101,7 +110,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {
"slideshow": {
"slide_type": "subslide"
@ -113,7 +122,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"load dataset from ./openml_ds537.pkl\n",
"download dataset from openml\n",
"Dataset name: houses\n",
"X_train.shape: (15480, 8), y_train.shape: (15480,);\n",
"X_test.shape: (5160, 8), y_test.shape: (5160,)\n"
@ -127,25 +136,38 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" median_income housing_median_age ... latitude longitude\n",
"19226 7.3003 19.0 ... 38.46 -122.68\n",
"14549 5.9547 18.0 ... 32.95 -117.24\n",
"9093 3.2125 19.0 ... 34.68 -118.27\n",
"12213 6.9930 13.0 ... 33.51 -117.18\n",
"12765 2.5162 21.0 ... 38.62 -121.41\n",
"... ... ... ... ... ...\n",
"13123 4.4125 20.0 ... 38.27 -121.26\n",
"19648 2.9135 27.0 ... 37.48 -120.89\n",
"9845 3.1977 31.0 ... 36.58 -121.90\n",
"10799 5.6315 34.0 ... 33.62 -117.93\n",
"2732 1.3882 15.0 ... 32.80 -115.56\n",
" median_income housing_median_age total_rooms total_bedrooms \\\n",
"19226 7.3003 19 4976.0 711.0 \n",
"14549 5.9547 18 1591.0 268.0 \n",
"9093 3.2125 19 552.0 129.0 \n",
"12213 6.9930 13 270.0 42.0 \n",
"12765 2.5162 21 3260.0 763.0 \n",
"... ... ... ... ... \n",
"13123 4.4125 20 1314.0 229.0 \n",
"19648 2.9135 27 1118.0 195.0 \n",
"9845 3.1977 31 1431.0 370.0 \n",
"10799 5.6315 34 2125.0 498.0 \n",
"2732 1.3882 15 1171.0 328.0 \n",
"\n",
" population households latitude longitude \n",
"19226 1926.0 625.0 38.46 -122.68 \n",
"14549 547.0 243.0 32.95 -117.24 \n",
"9093 314.0 106.0 34.68 -118.27 \n",
"12213 120.0 42.0 33.51 -117.18 \n",
"12765 1735.0 736.0 38.62 -121.41 \n",
"... ... ... ... ... \n",
"13123 712.0 219.0 38.27 -121.26 \n",
"19648 647.0 209.0 37.48 -120.89 \n",
"9845 704.0 393.0 36.58 -121.90 \n",
"10799 1052.0 468.0 33.62 -117.93 \n",
"2732 1024.0 298.0 32.80 -115.56 \n",
"\n",
"[15480 rows x 8 columns]\n"
]
@ -168,7 +190,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"metadata": {
"slideshow": {
"slide_type": "slide"
@ -176,6 +198,13 @@
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:flaml.default.suggest:metafeature distance: 0.02197989436019765\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@ -206,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"metadata": {
"slideshow": {
"slide_type": "slide"
@ -220,7 +249,7 @@
"0.8537444671194614"
]
},
"execution_count": 10,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@ -238,7 +267,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"metadata": {
"slideshow": {
"slide_type": "slide"
@ -251,7 +280,7 @@
"0.8296179648694404"
]
},
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
@ -309,9 +338,16 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:flaml.default.suggest:metafeature distance: 0.02197989436019765\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@ -341,9 +377,17 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 11,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:flaml.default.suggest:metafeature distance: 0.02197989436019765\n"
]
}
],
"source": [
"from flaml.default import preprocess_and_suggest_hyperparams\n",
"(\n",
@ -365,7 +409,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 12,
"metadata": {
"slideshow": {
"slide_type": "slide"
@ -394,7 +438,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 13,
"metadata": {
"slideshow": {
"slide_type": "slide"
@ -415,7 +459,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 14,
"metadata": {
"slideshow": {
"slide_type": "slide"
@ -425,6 +469,17 @@
"outputs": [
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: 
center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>LGBMRegressor(colsample_bytree=0.7019911744574896,\n",
" learning_rate=0.022635758411078528, max_bin=511,\n",
" min_child_samples=2, n_estimators=4797, num_leaves=122,\n",
" reg_alpha=0.004252223402511765, reg_lambda=0.11288241427227624,\n",
" verbose=-1)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">LGBMRegressor</label><div class=\"sk-toggleable__content\"><pre>LGBMRegressor(colsample_bytree=0.7019911744574896,\n",
" learning_rate=0.022635758411078528, max_bin=511,\n",
" min_child_samples=2, n_estimators=4797, num_leaves=122,\n",
" reg_alpha=0.004252223402511765, reg_lambda=0.11288241427227624,\n",
" verbose=-1)</pre></div></div></div></div></div>"
],
"text/plain": [
"LGBMRegressor(colsample_bytree=0.7019911744574896,\n",
" learning_rate=0.022635758411078528, max_bin=511,\n",
@ -433,7 +488,7 @@
" verbose=-1)"
]
},
"execution_count": 17,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@ -451,7 +506,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@ -480,35 +535,45 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[flaml.automl.logger: 04-28 02:51:45] {1663} INFO - task = regression\n",
"[flaml.automl.logger: 04-28 02:51:45] {1670} INFO - Data split method: uniform\n",
"[flaml.automl.logger: 04-28 02:51:45] {1673} INFO - Evaluation method: cv\n",
"[flaml.automl.logger: 04-28 02:51:45] {1771} INFO - Minimizing error metric: 1-r2\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"[flaml.automl: 05-31 22:54:25] {2373} INFO - task = regression\n",
"[flaml.automl: 05-31 22:54:25] {2375} INFO - Data split method: uniform\n",
"[flaml.automl: 05-31 22:54:25] {2379} INFO - Evaluation method: cv\n",
"[flaml.automl: 05-31 22:54:25] {2448} INFO - Minimizing error metric: 1-r2\n",
"[flaml.automl: 05-31 22:54:25] {2586} INFO - List of ML learners in AutoML Run: ['lgbm']\n",
"[flaml.automl: 05-31 22:54:25] {2878} INFO - iteration 0, current learner lgbm\n",
"[flaml.automl: 05-31 22:56:54] {3008} INFO - Estimated sufficient time budget=1490299s. Estimated necessary time budget=1490s.\n",
"[flaml.automl: 05-31 22:56:54] {3055} INFO - at 149.1s,\testimator lgbm's best error=0.1513,\tbest estimator lgbm's best error=0.1513\n",
"[flaml.automl: 05-31 22:56:54] {2878} INFO - iteration 1, current learner lgbm\n",
"[flaml.automl: 05-31 22:59:24] {3055} INFO - at 299.0s,\testimator lgbm's best error=0.1513,\tbest estimator lgbm's best error=0.1513\n",
"[flaml.automl: 05-31 22:59:24] {2878} INFO - iteration 2, current learner lgbm\n",
"[flaml.automl: 05-31 23:01:34] {3055} INFO - at 429.1s,\testimator lgbm's best error=0.1513,\tbest estimator lgbm's best error=0.1513\n",
"[flaml.automl: 05-31 23:01:34] {2878} INFO - iteration 3, current learner lgbm\n",
"[flaml.automl: 05-31 23:04:43] {3055} INFO - at 618.2s,\testimator lgbm's best error=0.1513,\tbest estimator lgbm's best error=0.1513\n",
"[flaml.automl: 05-31 23:05:14] {3315} INFO - retrain lgbm for 31.0s\n",
"[flaml.automl: 05-31 23:05:14] {3322} INFO - retrained model: LGBMRegressor(colsample_bytree=0.7019911744574896,\n",
"INFO:flaml.default.suggest:metafeature distance: 0.02197989436019765\n",
"INFO:flaml.default.suggest:metafeature distance: 0.006677018633540373\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[flaml.automl.logger: 04-28 02:51:45] {1881} INFO - List of ML learners in AutoML Run: ['lgbm']\n",
"[flaml.automl.logger: 04-28 02:51:45] {2191} INFO - iteration 0, current learner lgbm\n",
"[flaml.automl.logger: 04-28 02:53:39] {2317} INFO - Estimated sufficient time budget=1134156s. Estimated necessary time budget=1134s.\n",
"[flaml.automl.logger: 04-28 02:53:39] {2364} INFO - at 113.5s,\testimator lgbm's best error=0.1513,\tbest estimator lgbm's best error=0.1513\n",
"[flaml.automl.logger: 04-28 02:53:39] {2191} INFO - iteration 1, current learner lgbm\n",
"[flaml.automl.logger: 04-28 02:55:32] {2364} INFO - at 226.6s,\testimator lgbm's best error=0.1513,\tbest estimator lgbm's best error=0.1513\n",
"[flaml.automl.logger: 04-28 02:55:54] {2600} INFO - retrain lgbm for 22.3s\n",
"[flaml.automl.logger: 04-28 02:55:54] {2603} INFO - retrained model: LGBMRegressor(colsample_bytree=0.7019911744574896,\n",
" learning_rate=0.02263575841107852, max_bin=511,\n",
" min_child_samples=2, n_estimators=4797, num_leaves=122,\n",
" reg_alpha=0.004252223402511765, reg_lambda=0.11288241427227633,\n",
" reg_alpha=0.004252223402511765, reg_lambda=0.11288241427227624,\n",
" verbose=-1)\n",
"[flaml.automl: 05-31 23:05:14] {2617} INFO - fit succeeded\n",
"[flaml.automl: 05-31 23:05:14] {2618} INFO - Time taken to find the best model: 149.06516432762146\n"
"[flaml.automl.logger: 04-28 02:55:54] {1911} INFO - fit succeeded\n",
"[flaml.automl.logger: 04-28 02:55:54] {1912} INFO - Time taken to find the best model: 113.4601559638977\n"
]
}
],
@ -545,7 +610,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15 (main, Oct 26 2022, 03:47:43) \n[GCC 10.2.1 20210110]"
"version": "3.9.15"
}
},
"nbformat": 4,

View File

@ -15,11 +15,6 @@ __version__ = version["__version__"]
install_requires = [
"NumPy>=1.17.0rc1",
"lightgbm>=2.3.1",
"xgboost>=0.90",
"scipy>=1.4.1",
"pandas>=1.1.4",
"scikit-learn>=0.24",
]
@ -39,16 +34,28 @@ setuptools.setup(
include_package_data=True,
install_requires=install_requires,
extras_require={
"automl": [
"lightgbm>=2.3.1",
"xgboost>=0.90",
"scipy>=1.4.1",
"pandas>=1.1.4",
"scikit-learn>=0.24",
],
"notebook": [
"jupyter",
"matplotlib",
"openml==0.10.2",
"openml",
],
"spark": [
"pyspark>=3.2.0",
"joblibspark>=0.5.0",
],
"test": [
"lightgbm>=2.3.1",
"xgboost>=0.90",
"scipy>=1.4.1",
"pandas>=1.1.4",
"scikit-learn>=0.24",
"thop",
"pytest>=6.1.1",
"coverage>=5.3",
@ -58,7 +65,7 @@ setuptools.setup(
"catboost>=0.26,<1.2",
"rgf-python",
"optuna==2.8.0",
"openml==0.10.2",
"openml",
"statsmodels>=0.12.2",
"psutil==5.8.0",
"dataclasses",
@ -77,6 +84,7 @@ setuptools.setup(
"ipykernel",
"pytorch-lightning<1.9.1", # test_forecast_panel
"requests<2.29.0", # https://github.com/docker/docker-py/issues/3113
"packaging",
],
"catboost": ["catboost>=0.26"],
"blendsearch": ["optuna==2.8.0"],
@ -91,6 +99,7 @@ setuptools.setup(
],
"vw": [
"vowpalwabbit>=8.10.0, <9.0.0",
"scikit-learn",
],
"hf": [
"transformers[torch]==4.26",
@ -122,7 +131,12 @@ setuptools.setup(
"benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],
"openai": ["openai==0.27.4", "diskcache"],
"autogen": ["openai==0.27.4", "diskcache", "docker"],
"synapse": ["joblibspark>=0.5.0", "optuna==2.8.0", "pyspark>=3.2.0"],
"synapse": [
"joblibspark>=0.5.0",
"optuna==2.8.0",
"pyspark>=3.2.0",
],
"autozero": ["scikit-learn", "pandas", "packaging"],
},
classifiers=[
"Programming Language :: Python :: 3",

View File

@ -1,20 +0,0 @@
import numpy as np
from flaml.automl.utils import len_labels, unique_value_first_index


def test_len_labels():
    assert len_labels([1, 2, 3]) == 3
    assert len_labels([1, 2, 3, 1, 2, 3]) == 3
    assert np.array_equal(len_labels([1, 2, 3], True)[1], [1, 2, 3])
    assert np.array_equal(len_labels([1, 2, 3, 1, 2, 3], True)[1], [1, 2, 3])


def test_unique_value_first_index():
    label_set, first_index = unique_value_first_index([1, 2, 2, 3])
    assert np.array_equal(label_set, np.array([1, 2, 3]))
    assert np.array_equal(first_index, np.array([0, 1, 3]))


if __name__ == "__main__":
    test_len_labels()
    test_unique_value_first_index()

View File

@ -7,7 +7,7 @@ In this example, we will tune several hyperparameters for the OpenAI's completio
Install the [autogen,blendsearch] option.
```bash
pip install "flaml[autogen,blendsearch]==1.2.2 datasets"
pip install "flaml[autogen,blendsearch] datasets"
```
Set up your OpenAI key:
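For example, a minimal sketch (the environment-variable name is the standard one the `openai` package reads; substitute your own key, or export it in the shell before launching the notebook):

```python
import os

# assumption: replace the placeholder with your own key, or set OPENAI_API_KEY
# in the shell so the key never appears in code.
if "OPENAI_API_KEY" not in os.environ:
    os.environ["OPENAI_API_KEY"] = "<your OpenAI API key here>"
```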

View File

@ -1,5 +1,12 @@
# AutoML - Classification
### Prerequisites
Install the [automl] option.
```bash
pip install "flaml[automl]"
```
### A basic classification example
```python
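# A minimal sketch, not the repo's original snippet: it assumes the [automl]
# option above is installed and uses a small scikit-learn dataset; the time
# budget (seconds) is illustrative.
from flaml import AutoML
from sklearn.datasets import load_iris

X_train, y_train = load_iris(return_X_y=True)

automl = AutoML()
automl.fit(X_train, y_train, task="classification", metric="accuracy", time_budget=30)
print(automl.best_estimator)        # name of the best learner found
print(automl.predict(X_train)[:5])  # predictions from the retrained best model
```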

View File

@ -2,9 +2,9 @@
### Requirements
This example requires GPU. Install the [hf] option:
This example requires a GPU. Install the [automl,hf] option:
```bash
pip install "flaml[hf]"
pip install "flaml[automl,hf]"
```
### A simple sequence classification example
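A minimal sketch of what such an example looks like (assuming the [automl,hf] option above and a GPU; the toy in-memory dataset and the time budget are illustrative, the full example in the docs uses a real benchmark dataset):

```python
import pandas as pd
from flaml import AutoML

# toy text data; in practice load a real corpus (e.g. a GLUE task)
X_train = pd.DataFrame({"sentence": ["great movie", "terrible plot", "loved it", "boring"]})
y_train = pd.Series([1, 0, 1, 0])

automl = AutoML()
automl.fit(
    X_train,
    y_train,
    task="seq-classification",  # fine-tunes a transformer estimator
    metric="accuracy",
    time_budget=300,            # seconds
)
```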

View File

@ -1,5 +1,12 @@
# AutoML - Rank
### Prerequisites
Install the [automl] option.
```bash
pip install "flaml[automl]"
```
### A simple learning-to-rank example
```python
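# A minimal sketch, not the repo's original snippet: synthetic query-document
# data; `groups` gives the number of rows belonging to each query.
import numpy as np
from flaml import AutoML

rng = np.random.default_rng(0)
X_train = rng.normal(size=(100, 10))    # 100 query-document pairs, 10 features
y_train = rng.integers(0, 5, size=100)  # graded relevance labels 0..4
groups = [20, 20, 20, 20, 20]           # 5 queries with 20 documents each

automl = AutoML()
automl.fit(X_train, y_train, task="rank", groups=groups, time_budget=10)
print(automl.best_estimator)
```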

View File

@ -1,5 +1,12 @@
# AutoML - Regression
### Prerequisites
Install the [automl] option.
```bash
pip install "flaml[automl]"
```
### A basic regression example
```python
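# A minimal sketch, not the repo's original snippet, using a scikit-learn dataset;
# the time budget (seconds) is illustrative.
from flaml import AutoML
from sklearn.datasets import fetch_california_housing

X_train, y_train = fetch_california_housing(return_X_y=True)

automl = AutoML()
automl.fit(X_train, y_train, task="regression", metric="r2", time_budget=30)
print(automl.best_estimator)
print(1 - automl.best_loss)  # best validation r2 found within the budget
```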

View File

@ -2,9 +2,9 @@
### Prerequisites
Install the [ts_forecast] option.
Install the [automl,ts_forecast] option.
```bash
pip install "flaml[ts_forecast]"
pip install "flaml[automl,ts_forecast]"
```
### Simple NumPy Example
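A minimal sketch of the NumPy-based usage this section describes (assuming the [automl,ts_forecast] option above; the synthetic monthly series and the budget are illustrative):

```python
import numpy as np
from flaml import AutoML

# 84 monthly timestamps and a synthetic univariate series
X_train = np.arange("2014-01", "2021-01", dtype="datetime64[M]")
y_train = np.random.random(size=len(X_train))

automl = AutoML()
automl.fit(
    X_train=X_train,
    y_train=y_train,
    task="ts_forecast",
    period=12,       # forecast horizon: 12 future points
    time_budget=15,  # seconds
)

X_test = np.arange("2021-01", "2022-01", dtype="datetime64[M]")  # the next 12 months
y_pred = automl.predict(X_test)
```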

View File

@ -2,13 +2,11 @@
### Prerequisites for this example
Install the [notebook] option.
Install the [automl] option.
```bash
pip install "flaml[notebook]"
pip install "flaml[automl] matplotlib openml"
```
This option is not necessary in general.
### Use built-in LGBMEstimator
```python
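# A minimal sketch, not the repo's original snippet: restrict the search to the
# built-in LGBMEstimator via estimator_list; data and budget are illustrative.
from flaml import AutoML
from sklearn.datasets import fetch_california_housing

X_train, y_train = fetch_california_housing(return_X_y=True)

automl = AutoML()
automl.fit(X_train, y_train, task="regression", estimator_list=["lgbm"], time_budget=30)
print(automl.model.estimator)  # the underlying lightgbm model with tuned hyperparameters
```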

View File

@ -2,13 +2,11 @@
### Prerequisites for this example
Install the [notebook] option.
Install the [automl] option.
```bash
pip install "flaml[notebook]"
pip install "flaml[automl] matplotlib openml"
```
This option is not necessary in general.
### Use built-in XGBoostSklearnEstimator
```python
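# A minimal sketch, not the repo's original snippet: restrict the search to the
# built-in XGBoostSklearnEstimator via estimator_list; data and budget are illustrative.
from flaml import AutoML
from sklearn.datasets import fetch_california_housing

X_train, y_train = fetch_california_housing(return_X_y=True)

automl = AutoML()
automl.fit(X_train, y_train, task="regression", estimator_list=["xgboost"], time_budget=30)
print(automl.model.estimator)  # the underlying xgboost model with tuned hyperparameters
```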

View File

@ -2,10 +2,16 @@
Flamlized estimators automatically use data-dependent default hyperparameter configurations for each estimator, offering a unique zero-shot AutoML capability, or "no tuning" AutoML.
This example requires openml==0.10.2.
## Flamlized LGBMRegressor
### Prerequisites
This example requires the [autozero] option.
```bash
pip install "flaml[autozero]" lightgbm openml
```
### Zero-shot AutoML
```python
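# A minimal sketch, not the repo's original snippet: the flamlized LGBMRegressor
# picks data-dependent default hyperparameters inside fit(), with no tuning.
from flaml.default import LGBMRegressor
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

X, y = fetch_california_housing(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = LGBMRegressor()      # zero-shot: nothing to configure
estimator.fit(X_train, y_train)  # defaults chosen from the data's meta-features
print(estimator.score(X_test, y_test))  # r2 on the held-out split
```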
@ -62,6 +68,10 @@ X_test.shape: (5160, 8), y_test.shape: (5160,)
## Flamlized XGBClassifier
### Prerequisites
This example requires xgboost, sklearn, openml==0.10.2.
### Zero-shot AutoML
```python
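# A minimal sketch, not the repo's original snippet: same zero-shot pattern as
# above, with the flamlized XGBClassifier on a classification dataset.
from flaml.default import XGBClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = XGBClassifier()
clf.fit(X_train, y_train)            # defaults chosen from the data's meta-features
print(clf.score(X_test, y_test))     # accuracy on the held-out split
```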

View File

@ -2,9 +2,9 @@ FLAML can be used together with AzureML. On top of that, using mlflow and ray is
### Prerequisites
Install the [azureml] option.
Install the [automl,azureml] option.
```bash
pip install "flaml[azureml]"
pip install "flaml[automl,azureml]"
```
Set up an AzureML workspace:
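For example, a minimal sketch (it assumes a workspace already exists and that its `config.json` has been downloaded into the working directory):

```python
import mlflow
from azureml.core import Workspace

ws = Workspace.from_config()  # reads ./config.json downloaded from the Azure portal
mlflow.set_tracking_uri(ws.get_mlflow_tracking_uri())  # log FLAML runs to the workspace
```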

View File

@ -1,6 +1,11 @@
As FLAML's AutoML module can be used as an estimator in a Sklearn pipeline, we get all the benefits of a pipeline.
This example requires openml==0.10.2.
### Prerequisites
Install the [automl] option.
```bash
pip install "flaml[automl] openml"
```
### Load data
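A minimal sketch of this step (using scikit-learn's OpenML fetcher; the dataset name is illustrative and may differ from the one in the full example):

```python
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

X, y = fetch_openml(name="credit-g", version=1, return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
```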