mirror of https://github.com/microsoft/autogen.git
Support spark dataframe as input dataset and spark models as estimators (#934)
* add basic support to Spark dataframe add support to SynapseML LightGBM model update to pyspark>=3.2.0 to leverage pandas_on_Spark API * clean code, add TODOs * add sample_train_data for pyspark.pandas dataframe, fix bugs * improve some functions, fix bugs * fix dict change size during iteration * update model predict * update LightGBM model, update test * update SynapseML LightGBM params * update synapseML and tests * update TODOs * Added support to roc_auc for spark models * Added support to score of spark estimator * Added test for automl score of spark estimator * Added cv support to pyspark.pandas dataframe * Update test, fix bugs * Added tests * Updated docs, tests, added a notebook * Fix bugs in non-spark env * Fix bugs and improve tests * Fix uninstall pyspark * Fix tests error * Fix java.lang.OutOfMemoryError: Java heap space * Fix test_performance * Update test_sparkml to test_0sparkml to use the expected spark conf * Remove unnecessary widgets in notebook * Fix iloc java.lang.StackOverflowError * fix pre-commit * Added params check for spark dataframes * Refactor code for train_test_split to a function * Update train_test_split_pyspark * Refactor if-else, remove unnecessary code * Remove y from predict, remove mem control from n_iter compute * Update workflow * Improve _split_pyspark * Fix test failure of too short training time * Fix typos, improve docstrings * Fix index errors of pandas_on_spark, add spark loss metric * Fix typo of ndcgAtK * Update NDCG metrics and tests * Remove unuseful logger * Use cache and count to ensure consistent indexes * refactor for merge maain * fix errors of refactor * Updated SparkLightGBMEstimator and cache * Updated config2params * Remove unused import * Fix unknown parameters * Update default_estimator_list * Add unit tests for spark metrics
This commit is contained in:
parent
a3e770eac5
commit
50334f2c52
|
@ -25,7 +25,6 @@ jobs:
|
||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-latest, macos-latest, windows-2019]
|
os: [ubuntu-latest, macos-latest, windows-2019]
|
||||||
python-version: ["3.7", "3.8", "3.9", "3.10"]
|
python-version: ["3.7", "3.8", "3.9", "3.10"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v3
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
@ -45,21 +44,18 @@ jobs:
|
||||||
export CFLAGS="$CFLAGS -I/usr/local/opt/libomp/include"
|
export CFLAGS="$CFLAGS -I/usr/local/opt/libomp/include"
|
||||||
export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libomp/include"
|
export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libomp/include"
|
||||||
export LDFLAGS="$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp"
|
export LDFLAGS="$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp"
|
||||||
- name: On Linux, install Spark stand-alone cluster and PySpark
|
- name: On Linux + python 3.8, install pyspark 3.2.3
|
||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.8'
|
||||||
run: |
|
run: |
|
||||||
sudo apt-get update && sudo apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends ca-certificates-java ca-certificates openjdk-17-jdk-headless && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*
|
python -m pip install --upgrade pip wheel
|
||||||
wget --progress=dot:giga "https://www.apache.org/dyn/closer.lua/spark/spark-3.3.0/spark-3.3.0-bin-hadoop2.tgz?action=download" -O - | tar -xzC /tmp; archive=$(basename "spark-3.3.0/spark-3.3.0-bin-hadoop2.tgz") bash -c "sudo mv -v /tmp/\${archive/%.tgz/} /spark"
|
pip install pyspark==3.2.3
|
||||||
pip install --no-cache-dir pyspark>=3.0
|
|
||||||
export SPARK_HOME=/spark
|
|
||||||
export PYTHONPATH=/spark/python/lib/py4j-0.10.9.5-src.zip:/spark/python
|
|
||||||
export PATH=$PATH:$SPARK_HOME/bin
|
|
||||||
- name: Install packages and dependencies
|
- name: Install packages and dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip wheel
|
python -m pip install --upgrade pip wheel
|
||||||
pip install -e .
|
pip install -e .
|
||||||
python -c "import flaml"
|
python -c "import flaml"
|
||||||
pip install -e .[test]
|
pip install -e .[test]
|
||||||
|
pip list | grep "pyspark"
|
||||||
- name: If linux, install ray 2
|
- name: If linux, install ray 2
|
||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
run: |
|
run: |
|
||||||
|
@ -76,6 +72,11 @@ jobs:
|
||||||
if: matrix.python-version != '3.10'
|
if: matrix.python-version != '3.10'
|
||||||
run: |
|
run: |
|
||||||
pip install -e .[vw]
|
pip install -e .[vw]
|
||||||
|
- name: Uninstall pyspark on python 3.9
|
||||||
|
if: matrix.python-version == '3.9'
|
||||||
|
run: |
|
||||||
|
# Uninstall pyspark to test env without pyspark
|
||||||
|
pip uninstall -y pyspark
|
||||||
- name: Lint with flake8
|
- name: Lint with flake8
|
||||||
run: |
|
run: |
|
||||||
# stop the build if there are Python syntax errors or undefined names
|
# stop the build if there are Python syntax errors or undefined names
|
||||||
|
|
|
@ -159,3 +159,6 @@ automl.pkl
|
||||||
|
|
||||||
test/nlp/testtmp.py
|
test/nlp/testtmp.py
|
||||||
test/nlp/testtmpfl.py
|
test/nlp/testtmpfl.py
|
||||||
|
|
||||||
|
flaml/tune/spark/mylearner.py
|
||||||
|
*.pkl
|
||||||
|
|
|
@ -7,7 +7,6 @@ import time
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Callable, List, Union, Optional
|
from typing import Callable, List, Union, Optional
|
||||||
import inspect
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.base import BaseEstimator
|
from sklearn.base import BaseEstimator
|
||||||
|
@ -17,7 +16,6 @@ import json
|
||||||
|
|
||||||
from flaml.automl.state import SearchState, AutoMLState
|
from flaml.automl.state import SearchState, AutoMLState
|
||||||
from flaml.automl.ml import (
|
from flaml.automl.ml import (
|
||||||
compute_estimator,
|
|
||||||
train_estimator,
|
train_estimator,
|
||||||
get_estimator_class,
|
get_estimator_class,
|
||||||
)
|
)
|
||||||
|
@ -31,7 +29,6 @@ from flaml.config import (
|
||||||
N_SPLITS,
|
N_SPLITS,
|
||||||
SAMPLE_MULTIPLY_FACTOR,
|
SAMPLE_MULTIPLY_FACTOR,
|
||||||
)
|
)
|
||||||
from flaml.automl.data import concat
|
|
||||||
|
|
||||||
# TODO check to see when we can remove these
|
# TODO check to see when we can remove these
|
||||||
from flaml.automl.task.task import CLASSIFICATION, TS_FORECAST, Task
|
from flaml.automl.task.task import CLASSIFICATION, TS_FORECAST, Task
|
||||||
|
@ -43,6 +40,34 @@ from flaml.default import suggest_learner
|
||||||
from flaml.version import __version__ as flaml_version
|
from flaml.version import __version__ as flaml_version
|
||||||
from flaml.tune.spark.utils import check_spark, get_broadcast_data
|
from flaml.tune.spark.utils import check_spark, get_broadcast_data
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flaml.automl.spark.utils import (
|
||||||
|
train_test_split_pyspark,
|
||||||
|
unique_pandas_on_spark,
|
||||||
|
len_labels,
|
||||||
|
unique_value_first_index,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
train_test_split_pyspark = None
|
||||||
|
unique_pandas_on_spark = None
|
||||||
|
from flaml.automl.utils import (
|
||||||
|
len_labels,
|
||||||
|
unique_value_first_index,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
|
||||||
|
import pyspark.pandas as ps
|
||||||
|
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
|
||||||
|
from pyspark.pandas.config import set_option, reset_option
|
||||||
|
except ImportError:
|
||||||
|
ps = None
|
||||||
|
|
||||||
|
class psDataFrame:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class psSeries:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import mlflow
|
import mlflow
|
||||||
|
@ -511,7 +536,12 @@ class AutoML(BaseEstimator):
|
||||||
"""Time taken to find best model in seconds."""
|
"""Time taken to find best model in seconds."""
|
||||||
return self.__dict__.get("_time_taken_best_iter")
|
return self.__dict__.get("_time_taken_best_iter")
|
||||||
|
|
||||||
def score(self, X: pd.DataFrame, y: pd.Series, **kwargs):
|
def score(
|
||||||
|
self,
|
||||||
|
X: Union[pd.DataFrame, psDataFrame],
|
||||||
|
y: Union[pd.Series, psSeries],
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
estimator = getattr(self, "_trained_estimator", None)
|
estimator = getattr(self, "_trained_estimator", None)
|
||||||
if estimator is None:
|
if estimator is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
@ -525,13 +555,14 @@ class AutoML(BaseEstimator):
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
X: Union[np.array, pd.DataFrame, List[str], List[List[str]]],
|
X: Union[np.array, pd.DataFrame, List[str], List[List[str]], psDataFrame],
|
||||||
**pred_kwargs,
|
**pred_kwargs,
|
||||||
):
|
):
|
||||||
"""Predict label from features.
|
"""Predict label from features.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
X: A numpy array of featurized instances, shape n * m,
|
X: A numpy array or pandas dataframe or pyspark.pandas dataframe
|
||||||
|
of featurized instances, shape n * m,
|
||||||
or for time series forcast tasks:
|
or for time series forcast tasks:
|
||||||
a pandas dataframe with the first column containing
|
a pandas dataframe with the first column containing
|
||||||
timestamp values (datetime type) or an integer n for
|
timestamp values (datetime type) or an integer n for
|
||||||
|
@ -1859,7 +1890,19 @@ class AutoML(BaseEstimator):
|
||||||
error_metric = "customized metric"
|
error_metric = "customized metric"
|
||||||
logger.info(f"Minimizing error metric: {error_metric}")
|
logger.info(f"Minimizing error metric: {error_metric}")
|
||||||
|
|
||||||
estimator_list = task.default_estimator_list(estimator_list)
|
is_spark_dataframe = isinstance(X_train, psDataFrame) or isinstance(
|
||||||
|
dataframe, psDataFrame
|
||||||
|
)
|
||||||
|
estimator_list = task.default_estimator_list(estimator_list, is_spark_dataframe)
|
||||||
|
|
||||||
|
if is_spark_dataframe and self._use_spark:
|
||||||
|
# For spark dataframe, use_spark must be False because spark models are trained in parallel themselves
|
||||||
|
self._use_spark = False
|
||||||
|
logger.warning(
|
||||||
|
"Spark dataframes support only spark.ml type models, which will be trained "
|
||||||
|
"with spark themselves, no need to start spark trials in flaml. "
|
||||||
|
"`use_spark` is set to False."
|
||||||
|
)
|
||||||
|
|
||||||
# When no search budget is specified
|
# When no search budget is specified
|
||||||
if no_budget:
|
if no_budget:
|
||||||
|
|
|
@ -12,6 +12,22 @@ from flaml.automl.training_log import training_log_reader
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
|
||||||
|
import pyspark.pandas as ps
|
||||||
|
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
|
||||||
|
except ImportError:
|
||||||
|
ps = None
|
||||||
|
|
||||||
|
class psDataFrame:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class psSeries:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from flaml.automl.task import Task
|
from flaml.automl.task import Task
|
||||||
|
|
||||||
|
@ -198,6 +214,15 @@ def get_output_from_log(filename, time_budget):
|
||||||
|
|
||||||
def concat(X1, X2):
|
def concat(X1, X2):
|
||||||
"""concatenate two matrices vertically."""
|
"""concatenate two matrices vertically."""
|
||||||
|
if type(X1) != type(X2):
|
||||||
|
if isinstance(X2, (psDataFrame, psSeries)):
|
||||||
|
X1 = ps.from_pandas(pd.DataFrame(X1))
|
||||||
|
elif isinstance(X1, (psDataFrame, psSeries)):
|
||||||
|
X2 = ps.from_pandas(pd.DataFrame(X2))
|
||||||
|
else:
|
||||||
|
X1 = pd.DataFrame(X1)
|
||||||
|
X2 = pd.DataFrame(X2)
|
||||||
|
|
||||||
if isinstance(X1, (DataFrame, Series)):
|
if isinstance(X1, (DataFrame, Series)):
|
||||||
df = pd.concat([X1, X2], sort=False)
|
df = pd.concat([X1, X2], sort=False)
|
||||||
df.reset_index(drop=True, inplace=True)
|
df.reset_index(drop=True, inplace=True)
|
||||||
|
@ -206,6 +231,13 @@ def concat(X1, X2):
|
||||||
if len(cat_columns):
|
if len(cat_columns):
|
||||||
df[cat_columns] = df[cat_columns].astype("category")
|
df[cat_columns] = df[cat_columns].astype("category")
|
||||||
return df
|
return df
|
||||||
|
if isinstance(X1, (psDataFrame, psSeries)):
|
||||||
|
df = ps.concat([X1, X2], ignore_index=True)
|
||||||
|
if isinstance(X1, psDataFrame):
|
||||||
|
cat_columns = X1.select_dtypes(include="category").columns.values.tolist()
|
||||||
|
if len(cat_columns):
|
||||||
|
df[cat_columns] = df[cat_columns].astype("category")
|
||||||
|
return df
|
||||||
if issparse(X1):
|
if issparse(X1):
|
||||||
return vstack((X1, X2))
|
return vstack((X1, X2))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
# * Copyright (c) FLAML authors. All rights reserved.
|
# * Copyright (c) FLAML authors. All rights reserved.
|
||||||
# * Licensed under the MIT License. See LICENSE file in the
|
# * Licensed under the MIT License. See LICENSE file in the
|
||||||
# * project root for license information.
|
# * project root for license information.
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -19,12 +20,6 @@ from sklearn.metrics import (
|
||||||
mean_absolute_percentage_error,
|
mean_absolute_percentage_error,
|
||||||
ndcg_score,
|
ndcg_score,
|
||||||
)
|
)
|
||||||
from sklearn.model_selection import (
|
|
||||||
RepeatedStratifiedKFold,
|
|
||||||
GroupKFold,
|
|
||||||
TimeSeriesSplit,
|
|
||||||
StratifiedGroupKFold,
|
|
||||||
)
|
|
||||||
from flaml.automl.model import (
|
from flaml.automl.model import (
|
||||||
XGBoostSklearnEstimator,
|
XGBoostSklearnEstimator,
|
||||||
XGBoost_TS,
|
XGBoost_TS,
|
||||||
|
@ -46,14 +41,33 @@ from flaml.automl.model import (
|
||||||
TransformersEstimator,
|
TransformersEstimator,
|
||||||
TemporalFusionTransformerEstimator,
|
TemporalFusionTransformerEstimator,
|
||||||
TransformersEstimatorModelSelection,
|
TransformersEstimatorModelSelection,
|
||||||
|
SparkLGBMEstimator,
|
||||||
)
|
)
|
||||||
from flaml.automl.data import group_counts
|
from flaml.automl.data import group_counts
|
||||||
from flaml.automl.task.task import TS_FORECAST, Task
|
from flaml.automl.task.task import TS_FORECAST, Task
|
||||||
from flaml.automl.model import BaseEstimator
|
from flaml.automl.model import BaseEstimator
|
||||||
|
|
||||||
import logging
|
try:
|
||||||
|
from flaml.automl.spark.utils import len_labels
|
||||||
|
except ImportError:
|
||||||
|
from flaml.automl.utils import len_labels
|
||||||
|
try:
|
||||||
|
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
|
||||||
|
from pyspark.sql.functions import col
|
||||||
|
import pyspark.pandas as ps
|
||||||
|
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
|
||||||
|
from flaml.automl.spark.utils import to_pandas_on_spark, iloc_pandas_on_spark
|
||||||
|
from flaml.automl.spark.metrics import spark_metric_loss_score
|
||||||
|
except ImportError:
|
||||||
|
ps = None
|
||||||
|
|
||||||
|
class psDataFrame:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class psSeries:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
EstimatorSubclass = TypeVar("EstimatorSubclass", bound=BaseEstimator)
|
EstimatorSubclass = TypeVar("EstimatorSubclass", bound=BaseEstimator)
|
||||||
|
|
||||||
sklearn_metric_name_set = {
|
sklearn_metric_name_set = {
|
||||||
|
@ -124,6 +138,8 @@ def get_estimator_class(task: str, estimator_name: str) -> EstimatorSubclass:
|
||||||
estimator_class = RF_TS if task in TS_FORECAST else RandomForestEstimator
|
estimator_class = RF_TS if task in TS_FORECAST else RandomForestEstimator
|
||||||
elif "lgbm" == estimator_name:
|
elif "lgbm" == estimator_name:
|
||||||
estimator_class = LGBM_TS if task in TS_FORECAST else LGBMEstimator
|
estimator_class = LGBM_TS if task in TS_FORECAST else LGBMEstimator
|
||||||
|
elif "lgbm_spark" == estimator_name:
|
||||||
|
estimator_class = SparkLGBMEstimator
|
||||||
elif "lrl1" == estimator_name:
|
elif "lrl1" == estimator_name:
|
||||||
estimator_class = LRL1Classifier
|
estimator_class = LRL1Classifier
|
||||||
elif "lrl2" == estimator_name:
|
elif "lrl2" == estimator_name:
|
||||||
|
@ -163,7 +179,15 @@ def metric_loss_score(
|
||||||
groups=None,
|
groups=None,
|
||||||
):
|
):
|
||||||
# y_processed_predict and y_processed_true are processed id labels if the original were the token labels
|
# y_processed_predict and y_processed_true are processed id labels if the original were the token labels
|
||||||
if is_in_sklearn_metric_name_set(metric_name):
|
if isinstance(y_processed_predict, (psDataFrame, psSeries)):
|
||||||
|
return spark_metric_loss_score(
|
||||||
|
metric_name,
|
||||||
|
y_processed_predict,
|
||||||
|
y_processed_true,
|
||||||
|
sample_weight,
|
||||||
|
groups,
|
||||||
|
)
|
||||||
|
elif is_in_sklearn_metric_name_set(metric_name):
|
||||||
return sklearn_metric_loss_score(
|
return sklearn_metric_loss_score(
|
||||||
metric_name,
|
metric_name,
|
||||||
y_processed_predict,
|
y_processed_predict,
|
||||||
|
@ -359,7 +383,10 @@ def sklearn_metric_loss_score(
|
||||||
def get_y_pred(estimator, X, eval_metric, task: Task):
|
def get_y_pred(estimator, X, eval_metric, task: Task):
|
||||||
if eval_metric in ["roc_auc", "ap", "roc_auc_weighted"] and task.is_binary():
|
if eval_metric in ["roc_auc", "ap", "roc_auc_weighted"] and task.is_binary():
|
||||||
y_pred_classes = estimator.predict_proba(X)
|
y_pred_classes = estimator.predict_proba(X)
|
||||||
y_pred = y_pred_classes[:, 1] if y_pred_classes.ndim > 1 else y_pred_classes
|
if isinstance(y_pred_classes, (psSeries, psDataFrame)):
|
||||||
|
y_pred = y_pred_classes
|
||||||
|
else:
|
||||||
|
y_pred = y_pred_classes[:, 1] if y_pred_classes.ndim > 1 else y_pred_classes
|
||||||
elif eval_metric in [
|
elif eval_metric in [
|
||||||
"log_loss",
|
"log_loss",
|
||||||
"roc_auc",
|
"roc_auc",
|
||||||
|
@ -525,7 +552,7 @@ def compute_estimator(
|
||||||
fit_kwargs: Optional[dict] = None,
|
fit_kwargs: Optional[dict] = None,
|
||||||
free_mem_ratio=0,
|
free_mem_ratio=0,
|
||||||
):
|
):
|
||||||
if not fit_kwargs:
|
if fit_kwargs is None:
|
||||||
fit_kwargs = {}
|
fit_kwargs = {}
|
||||||
|
|
||||||
estimator_class = estimator_class or get_estimator_class(task, estimator_name)
|
estimator_class = estimator_class or get_estimator_class(task, estimator_name)
|
||||||
|
@ -605,7 +632,7 @@ def train_estimator(
|
||||||
task=task,
|
task=task,
|
||||||
n_jobs=n_jobs,
|
n_jobs=n_jobs,
|
||||||
)
|
)
|
||||||
if not fit_kwargs:
|
if fit_kwargs is None:
|
||||||
fit_kwargs = {}
|
fit_kwargs = {}
|
||||||
|
|
||||||
if isinstance(estimator, TransformersEstimator):
|
if isinstance(estimator, TransformersEstimator):
|
||||||
|
|
|
@ -6,7 +6,7 @@ from contextlib import contextmanager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import signal
|
import signal
|
||||||
import os
|
import os
|
||||||
from typing import Callable, List
|
from typing import Callable, List, Union
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
|
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
|
||||||
|
@ -36,6 +36,38 @@ from flaml.automl.task.task import (
|
||||||
NLG_TASKS,
|
NLG_TASKS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flaml.automl.spark.utils import len_labels, to_pandas_on_spark
|
||||||
|
except ImportError:
|
||||||
|
from flaml.automl.utils import len_labels
|
||||||
|
|
||||||
|
to_pandas_on_spark = None
|
||||||
|
from flaml.automl.spark.configs import (
|
||||||
|
ParamList_LightGBM_Classifier,
|
||||||
|
ParamList_LightGBM_Regressor,
|
||||||
|
ParamList_LightGBM_Ranker,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
|
||||||
|
from pyspark.sql.dataframe import DataFrame as sparkDataFrame
|
||||||
|
from pyspark.sql import SparkSession
|
||||||
|
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
|
||||||
|
|
||||||
|
_have_spark = True
|
||||||
|
except ImportError:
|
||||||
|
_have_spark = False
|
||||||
|
|
||||||
|
class psDataFrame:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class psSeries:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class sparkDataFrame:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import psutil
|
import psutil
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -388,6 +420,323 @@ class BaseEstimator:
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
class SparkEstimator(BaseEstimator):
|
||||||
|
"""The base class for fine-tuning spark models, using pyspark.ml and SynapseML API."""
|
||||||
|
|
||||||
|
def __init__(self, task="binary", **config):
|
||||||
|
if not _have_spark:
|
||||||
|
raise ImportError(
|
||||||
|
"pyspark is not installed. Try `pip install flaml[spark]`."
|
||||||
|
)
|
||||||
|
super().__init__(task, **config)
|
||||||
|
self.df_train = None
|
||||||
|
|
||||||
|
def _preprocess(
|
||||||
|
self,
|
||||||
|
X_train: Union[psDataFrame, sparkDataFrame],
|
||||||
|
y_train: psSeries = None,
|
||||||
|
index_col: str = "tmp_index_col",
|
||||||
|
):
|
||||||
|
# TODO: optimize this, support pyspark.sql.DataFrame
|
||||||
|
if y_train is not None:
|
||||||
|
self.df_train = X_train.join(y_train)
|
||||||
|
else:
|
||||||
|
self.df_train = X_train
|
||||||
|
if isinstance(self.df_train, psDataFrame):
|
||||||
|
self.df_train = self.df_train.to_spark(index_col=index_col)
|
||||||
|
return self.df_train
|
||||||
|
|
||||||
|
def fit(
|
||||||
|
self,
|
||||||
|
X_train: psDataFrame,
|
||||||
|
y_train: psSeries = None,
|
||||||
|
budget=None,
|
||||||
|
free_mem_ratio=0,
|
||||||
|
index_col: str = "tmp_index_col",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""Train the model from given training data.
|
||||||
|
Args:
|
||||||
|
X_train: A pyspark.pandas DataFrame of training data in shape n*m.
|
||||||
|
y_train: A pyspark.pandas Series in shape n*1. None if X_train is a pyspark.pandas
|
||||||
|
Dataframe contains y_train.
|
||||||
|
budget: A float of the time budget in seconds.
|
||||||
|
free_mem_ratio: A float between 0 and 1 for the free memory ratio to keep during training.
|
||||||
|
Returns:
|
||||||
|
train_time: A float of the training time in seconds.
|
||||||
|
"""
|
||||||
|
df_train = self._preprocess(X_train, y_train, index_col=index_col)
|
||||||
|
train_time = self._fit(df_train, **kwargs)
|
||||||
|
return train_time
|
||||||
|
|
||||||
|
def _fit(self, df_train: sparkDataFrame, **kwargs):
|
||||||
|
current_time = time.time()
|
||||||
|
pipeline_model = self.estimator_class(**self.params, **kwargs)
|
||||||
|
if logger.level == logging.DEBUG:
|
||||||
|
logger.debug(
|
||||||
|
f"flaml.model - {pipeline_model} fit started with params {self.params}"
|
||||||
|
)
|
||||||
|
pipeline_model.fit(df_train)
|
||||||
|
if logger.level == logging.DEBUG:
|
||||||
|
logger.debug(f"flaml.model - {pipeline_model} fit finished")
|
||||||
|
train_time = time.time() - current_time
|
||||||
|
self._model = pipeline_model
|
||||||
|
return train_time
|
||||||
|
|
||||||
|
def predict(self, X, index_col="tmp_index_col", return_all=False, **kwargs):
|
||||||
|
"""Predict label from features.
|
||||||
|
Args:
|
||||||
|
X: A pyspark or pyspark.pandas dataframe of featurized instances, shape n*m.
|
||||||
|
index_col: A str of the index column name. Default to "tmp_index_col".
|
||||||
|
return_all: A bool of whether to return all the prediction results. Default to False.
|
||||||
|
Returns:
|
||||||
|
A pyspark.pandas series of shape n*1 if return_all is False. Otherwise, a pyspark.pandas dataframe.
|
||||||
|
"""
|
||||||
|
if self._model is not None:
|
||||||
|
X = self._preprocess(X, index_col=index_col)
|
||||||
|
predictions = to_pandas_on_spark(
|
||||||
|
self._model.transform(X), index_col=index_col
|
||||||
|
)
|
||||||
|
predictions.index.name = None
|
||||||
|
pred_y = predictions["prediction"]
|
||||||
|
if return_all:
|
||||||
|
return predictions
|
||||||
|
else:
|
||||||
|
return pred_y
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Estimator is not fit yet. Please run fit() before predict()."
|
||||||
|
)
|
||||||
|
return np.ones(X.shape[0])
|
||||||
|
|
||||||
|
def predict_proba(self, X, index_col="tmp_index_col", return_all=False, **kwargs):
|
||||||
|
"""Predict the probability of each class from features.
|
||||||
|
Only works for classification problems
|
||||||
|
Args:
|
||||||
|
X: A pyspark or pyspark.pandas dataframe of featurized instances, shape n*m.
|
||||||
|
index_col: A str of the index column name. Default to "tmp_index_col".
|
||||||
|
return_all: A bool of whether to return all the prediction results. Default to False.
|
||||||
|
Returns:
|
||||||
|
A pyspark.pandas dataframe of shape n*c. c is the # classes.
|
||||||
|
Each element at (i,j) is the probability for instance i to be in
|
||||||
|
class j.
|
||||||
|
"""
|
||||||
|
assert self._task in CLASSIFICATION, "predict_proba() only for classification."
|
||||||
|
if self._model is not None:
|
||||||
|
X = self._preprocess(X, index_col=index_col)
|
||||||
|
predictions = to_pandas_on_spark(
|
||||||
|
self._model.transform(X), index_col=index_col
|
||||||
|
)
|
||||||
|
predictions.index.name = None
|
||||||
|
pred_y = predictions["probability"]
|
||||||
|
|
||||||
|
if return_all:
|
||||||
|
return predictions
|
||||||
|
else:
|
||||||
|
return pred_y
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Estimator is not fit yet. Please run fit() before predict()."
|
||||||
|
)
|
||||||
|
return np.ones(X.shape[0])
|
||||||
|
|
||||||
|
|
||||||
|
class SparkLGBMEstimator(SparkEstimator):
|
||||||
|
"""The class for fine-tuning spark version lightgbm models, using SynapseML API."""
|
||||||
|
|
||||||
|
"""The class for tuning LGBM, using sklearn API."""
|
||||||
|
|
||||||
|
ITER_HP = "numIterations"
|
||||||
|
DEFAULT_ITER = 100
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def search_space(cls, data_size, **params):
|
||||||
|
upper = max(5, min(32768, int(data_size[0]))) # upper must be larger than lower
|
||||||
|
# https://github.com/microsoft/SynapseML/blob/master/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala
|
||||||
|
return {
|
||||||
|
"numIterations": {
|
||||||
|
"domain": tune.lograndint(lower=4, upper=upper),
|
||||||
|
"init_value": 4,
|
||||||
|
"low_cost_init_value": 4,
|
||||||
|
},
|
||||||
|
"numLeaves": {
|
||||||
|
"domain": tune.lograndint(lower=4, upper=upper),
|
||||||
|
"init_value": 4,
|
||||||
|
"low_cost_init_value": 4,
|
||||||
|
},
|
||||||
|
"minDataInLeaf": {
|
||||||
|
"domain": tune.lograndint(lower=2, upper=2**7 + 1),
|
||||||
|
"init_value": 20,
|
||||||
|
},
|
||||||
|
"learningRate": {
|
||||||
|
"domain": tune.loguniform(lower=1 / 1024, upper=1.0),
|
||||||
|
"init_value": 0.1,
|
||||||
|
},
|
||||||
|
"log_max_bin": { # log transformed with base 2
|
||||||
|
"domain": tune.lograndint(lower=3, upper=11),
|
||||||
|
"init_value": 8,
|
||||||
|
},
|
||||||
|
"featureFraction": {
|
||||||
|
"domain": tune.uniform(lower=0.01, upper=1.0),
|
||||||
|
"init_value": 1.0,
|
||||||
|
},
|
||||||
|
"lambdaL1": {
|
||||||
|
"domain": tune.loguniform(lower=1 / 1024, upper=1024),
|
||||||
|
"init_value": 1 / 1024,
|
||||||
|
},
|
||||||
|
"lambdaL2": {
|
||||||
|
"domain": tune.loguniform(lower=1 / 1024, upper=1024),
|
||||||
|
"init_value": 1.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def config2params(self, config: dict) -> dict:
|
||||||
|
params = super().config2params(config)
|
||||||
|
if "n_jobs" in params:
|
||||||
|
params.pop("n_jobs")
|
||||||
|
if "log_max_bin" in params:
|
||||||
|
params["maxBin"] = (1 << params.pop("log_max_bin")) - 1
|
||||||
|
return params
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def size(cls, config):
|
||||||
|
num_leaves = int(
|
||||||
|
round(config.get("numLeaves") or 1 << config.get("maxDepth", 16))
|
||||||
|
)
|
||||||
|
n_estimators = int(round(config["numIterations"]))
|
||||||
|
return (num_leaves * 3 + (num_leaves - 1) * 4 + 1.0) * n_estimators * 8
|
||||||
|
|
||||||
|
def __init__(self, task="binary", **config):
|
||||||
|
super().__init__(task, **config)
|
||||||
|
err_msg = (
|
||||||
|
"SynapseML is not installed. Please refer to [SynapseML]"
|
||||||
|
+ "(https://github.com/microsoft/SynapseML) for installation instructions."
|
||||||
|
)
|
||||||
|
if "regression" == task:
|
||||||
|
try:
|
||||||
|
from synapse.ml.lightgbm import LightGBMRegressor
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(err_msg)
|
||||||
|
|
||||||
|
self.estimator_class = LightGBMRegressor
|
||||||
|
self.estimator_params = ParamList_LightGBM_Regressor
|
||||||
|
elif "rank" == task:
|
||||||
|
try:
|
||||||
|
from synapse.ml.lightgbm import LightGBMRanker
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(err_msg)
|
||||||
|
|
||||||
|
self.estimator_class = LightGBMRanker
|
||||||
|
self.estimator_params = ParamList_LightGBM_Ranker
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from synapse.ml.lightgbm import LightGBMClassifier
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(err_msg)
|
||||||
|
|
||||||
|
self.estimator_class = LightGBMClassifier
|
||||||
|
self.estimator_params = ParamList_LightGBM_Classifier
|
||||||
|
self._time_per_iter = None
|
||||||
|
self._train_size = 0
|
||||||
|
self._mem_per_iter = -1
|
||||||
|
self.model_classes_ = None
|
||||||
|
self.model_n_classes_ = None
|
||||||
|
|
||||||
|
def fit(
|
||||||
|
self,
|
||||||
|
X_train,
|
||||||
|
y_train=None,
|
||||||
|
budget=None,
|
||||||
|
free_mem_ratio=0,
|
||||||
|
index_col="tmp_index_col",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
start_time = time.time()
|
||||||
|
if self.model_n_classes_ is None and self._task not in ["regression", "rank"]:
|
||||||
|
self.model_n_classes_, self.model_classes_ = len_labels(
|
||||||
|
y_train, return_labels=True
|
||||||
|
)
|
||||||
|
df_train = self._preprocess(X_train, y_train, index_col=index_col)
|
||||||
|
# n_iter = self.params.get(self.ITER_HP, self.DEFAULT_ITER)
|
||||||
|
# trained = False
|
||||||
|
# mem0 = psutil.virtual_memory().available if psutil is not None else 1
|
||||||
|
_kwargs = kwargs.copy()
|
||||||
|
if self._task not in ["regression", "rank"] and "objective" not in _kwargs:
|
||||||
|
_kwargs["objective"] = (
|
||||||
|
"binary" if self.model_n_classes_ == 2 else "multiclass"
|
||||||
|
)
|
||||||
|
for k in list(_kwargs.keys()):
|
||||||
|
if k not in self.estimator_params:
|
||||||
|
logger.warning(
|
||||||
|
f"[SparkLGBMEstimator] [Warning] Ignored unknown parameter: {k}"
|
||||||
|
)
|
||||||
|
_kwargs.pop(k)
|
||||||
|
# TODO: find a better estimation of early stopping
|
||||||
|
# if (
|
||||||
|
# (not self._time_per_iter or abs(self._train_size - df_train.count()) > 4)
|
||||||
|
# and budget is not None
|
||||||
|
# or self._mem_per_iter < 0
|
||||||
|
# and psutil is not None
|
||||||
|
# ) and n_iter > 1:
|
||||||
|
# self.params[self.ITER_HP] = 1
|
||||||
|
# self._t1 = self._fit(df_train, **_kwargs)
|
||||||
|
# if budget is not None and self._t1 >= budget or n_iter == 1:
|
||||||
|
# return self._t1
|
||||||
|
# mem1 = psutil.virtual_memory().available if psutil is not None else 1
|
||||||
|
# self._mem1 = mem0 - mem1
|
||||||
|
# self.params[self.ITER_HP] = min(n_iter, 4)
|
||||||
|
# self._t2 = self._fit(df_train, **_kwargs)
|
||||||
|
# mem2 = psutil.virtual_memory().available if psutil is not None else 1
|
||||||
|
# self._mem2 = max(mem0 - mem2, self._mem1)
|
||||||
|
# self._mem_per_iter = min(self._mem1, self._mem2 / self.params[self.ITER_HP])
|
||||||
|
# self._time_per_iter = (
|
||||||
|
# (self._t2 - self._t1) / (self.params[self.ITER_HP] - 1)
|
||||||
|
# if self._t2 > self._t1
|
||||||
|
# else self._t1
|
||||||
|
# if self._t1
|
||||||
|
# else 0.001
|
||||||
|
# )
|
||||||
|
# self._train_size = df_train.count()
|
||||||
|
# if (
|
||||||
|
# budget is not None
|
||||||
|
# and self._t1 + self._t2 >= budget
|
||||||
|
# or n_iter == self.params[self.ITER_HP]
|
||||||
|
# ):
|
||||||
|
# # self.params[self.ITER_HP] = n_iter
|
||||||
|
# return time.time() - start_time
|
||||||
|
# trained = True
|
||||||
|
# if n_iter > 1:
|
||||||
|
# max_iter = min(
|
||||||
|
# n_iter,
|
||||||
|
# int(
|
||||||
|
# (budget - time.time() + start_time - self._t1) / self._time_per_iter
|
||||||
|
# + 1
|
||||||
|
# )
|
||||||
|
# if budget is not None
|
||||||
|
# else n_iter,
|
||||||
|
# )
|
||||||
|
# if trained and max_iter <= self.params[self.ITER_HP]:
|
||||||
|
# return time.time() - start_time
|
||||||
|
# # when not trained, train at least one iter
|
||||||
|
# self.params[self.ITER_HP] = max(max_iter, 1)
|
||||||
|
self._fit(df_train, **_kwargs)
|
||||||
|
train_time = time.time() - start_time
|
||||||
|
return train_time
|
||||||
|
|
||||||
|
def _fit(self, df_train: sparkDataFrame, **kwargs):
|
||||||
|
current_time = time.time()
|
||||||
|
model = self.estimator_class(**self.params, **kwargs)
|
||||||
|
if logger.level == logging.DEBUG:
|
||||||
|
logger.debug(f"flaml.model - {model} fit started with params {self.params}")
|
||||||
|
self._model = model.fit(df_train)
|
||||||
|
self._model.classes_ = self.model_classes_
|
||||||
|
self._model.n_classes_ = self.model_n_classes_
|
||||||
|
if logger.level == logging.DEBUG:
|
||||||
|
logger.debug(f"flaml.model - {model} fit finished")
|
||||||
|
train_time = time.time() - current_time
|
||||||
|
return train_time
|
||||||
|
|
||||||
|
|
||||||
class TransformersEstimator(BaseEstimator):
|
class TransformersEstimator(BaseEstimator):
|
||||||
"""The class for fine-tuning language models, using huggingface transformers API."""
|
"""The class for fine-tuning language models, using huggingface transformers API."""
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,97 @@
|
||||||
|
ParamList_LightGBM_Base = [
|
||||||
|
"baggingFraction",
|
||||||
|
"baggingFreq",
|
||||||
|
"baggingSeed",
|
||||||
|
"binSampleCount",
|
||||||
|
"boostFromAverage",
|
||||||
|
"boostingType",
|
||||||
|
"catSmooth",
|
||||||
|
"categoricalSlotIndexes",
|
||||||
|
"categoricalSlotNames",
|
||||||
|
"catl2",
|
||||||
|
"chunkSize",
|
||||||
|
"dataRandomSeed",
|
||||||
|
"defaultListenPort",
|
||||||
|
"deterministic",
|
||||||
|
"driverListenPort",
|
||||||
|
"dropRate",
|
||||||
|
"dropSeed",
|
||||||
|
"earlyStoppingRound",
|
||||||
|
"executionMode",
|
||||||
|
"extraSeed" "featureFraction",
|
||||||
|
"featureFractionByNode",
|
||||||
|
"featureFractionSeed",
|
||||||
|
"featuresCol",
|
||||||
|
"featuresShapCol",
|
||||||
|
"fobj" "improvementTolerance",
|
||||||
|
"initScoreCol",
|
||||||
|
"isEnableSparse",
|
||||||
|
"isProvideTrainingMetric",
|
||||||
|
"labelCol",
|
||||||
|
"lambdaL1",
|
||||||
|
"lambdaL2",
|
||||||
|
"leafPredictionCol",
|
||||||
|
"learningRate",
|
||||||
|
"matrixType",
|
||||||
|
"maxBin",
|
||||||
|
"maxBinByFeature",
|
||||||
|
"maxCatThreshold",
|
||||||
|
"maxCatToOnehot",
|
||||||
|
"maxDeltaStep",
|
||||||
|
"maxDepth",
|
||||||
|
"maxDrop",
|
||||||
|
"metric",
|
||||||
|
"microBatchSize",
|
||||||
|
"minDataInLeaf",
|
||||||
|
"minDataPerBin",
|
||||||
|
"minDataPerGroup",
|
||||||
|
"minGainToSplit",
|
||||||
|
"minSumHessianInLeaf",
|
||||||
|
"modelString",
|
||||||
|
"monotoneConstraints",
|
||||||
|
"monotoneConstraintsMethod",
|
||||||
|
"monotonePenalty",
|
||||||
|
"negBaggingFraction",
|
||||||
|
"numBatches",
|
||||||
|
"numIterations",
|
||||||
|
"numLeaves",
|
||||||
|
"numTasks",
|
||||||
|
"numThreads",
|
||||||
|
"objectiveSeed",
|
||||||
|
"otherRate",
|
||||||
|
"parallelism",
|
||||||
|
"passThroughArgs",
|
||||||
|
"posBaggingFraction",
|
||||||
|
"predictDisableShapeCheck",
|
||||||
|
"predictionCol",
|
||||||
|
"repartitionByGroupingColumn",
|
||||||
|
"seed",
|
||||||
|
"skipDrop",
|
||||||
|
"slotNames",
|
||||||
|
"timeout",
|
||||||
|
"topK",
|
||||||
|
"topRate",
|
||||||
|
"uniformDrop",
|
||||||
|
"useBarrierExecutionMode",
|
||||||
|
"useMissing",
|
||||||
|
"useSingleDatasetMode",
|
||||||
|
"validationIndicatorCol",
|
||||||
|
"verbosity",
|
||||||
|
"weightCol",
|
||||||
|
"xGBoostDartMode",
|
||||||
|
"zeroAsMissing",
|
||||||
|
"objective",
|
||||||
|
]
|
||||||
|
ParamList_LightGBM_Classifier = ParamList_LightGBM_Base + [
|
||||||
|
"isUnbalance",
|
||||||
|
"probabilityCol",
|
||||||
|
"rawPredictionCol",
|
||||||
|
"thresholds",
|
||||||
|
]
|
||||||
|
ParamList_LightGBM_Regressor = ParamList_LightGBM_Base + ["tweedieVariancePower"]
|
||||||
|
ParamList_LightGBM_Ranker = ParamList_LightGBM_Base + [
|
||||||
|
"groupCol",
|
||||||
|
"evalAt",
|
||||||
|
"labelGain",
|
||||||
|
"maxPosition",
|
||||||
|
]
|
|
@ -0,0 +1,230 @@
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
|
||||||
|
from pyspark.sql import DataFrame
|
||||||
|
import pyspark.pandas as ps
|
||||||
|
from pyspark.ml.evaluation import (
|
||||||
|
BinaryClassificationEvaluator,
|
||||||
|
RegressionEvaluator,
|
||||||
|
MulticlassClassificationEvaluator,
|
||||||
|
MultilabelClassificationEvaluator,
|
||||||
|
RankingEvaluator,
|
||||||
|
)
|
||||||
|
import pyspark.sql.functions as F
|
||||||
|
except ImportError:
|
||||||
|
msg = """use_spark=True requires installation of PySpark. Please run pip install flaml[spark]
|
||||||
|
and check [here](https://spark.apache.org/docs/latest/api/python/getting_started/install.html)
|
||||||
|
for more details about installing Spark."""
|
||||||
|
raise ImportError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def ps_group_counts(groups: Union[ps.Series, np.ndarray]) -> np.ndarray:
|
||||||
|
if isinstance(groups, np.ndarray):
|
||||||
|
_, i, c = np.unique(groups, return_counts=True, return_index=True)
|
||||||
|
else:
|
||||||
|
i = groups.drop_duplicates().index.values
|
||||||
|
c = groups.value_counts().sort_index().to_numpy()
|
||||||
|
return c[np.argsort(i)].tolist()
|
||||||
|
|
||||||
|
|
||||||
|
def _process_df(df, label_col, prediction_col):
|
||||||
|
df = df.withColumn(label_col, F.array([df[label_col]]))
|
||||||
|
df = df.withColumn(prediction_col, F.array([df[prediction_col]]))
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_label_from_probability(df, probability_col, prediction_col):
|
||||||
|
# array_max finds the maximum value in the 'probability' array
|
||||||
|
# array_position finds the index of the maximum value in the 'probability' array
|
||||||
|
max_index_expr = F.expr(
|
||||||
|
f"array_position({probability_col}, array_max({probability_col}))-1"
|
||||||
|
)
|
||||||
|
# Create a new column 'prediction' based on the maximum probability value
|
||||||
|
df = df.withColumn(prediction_col, max_index_expr.cast("double"))
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def spark_metric_loss_score(
|
||||||
|
metric_name: str,
|
||||||
|
y_predict: ps.Series,
|
||||||
|
y_true: ps.Series,
|
||||||
|
sample_weight: ps.Series = None,
|
||||||
|
groups: ps.Series = None,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
Compute the loss score of a metric for spark models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metric_name: str | the name of the metric.
|
||||||
|
y_predict: ps.Series | the predicted values.
|
||||||
|
y_true: ps.Series | the true values.
|
||||||
|
sample_weight: ps.Series | the sample weights. Default: None.
|
||||||
|
groups: ps.Series | the group of each row. Default: None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float | the loss score. A lower value indicates a better model.
|
||||||
|
"""
|
||||||
|
label_col = "label"
|
||||||
|
prediction_col = "prediction"
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
|
y_predict.name = prediction_col
|
||||||
|
y_true.name = label_col
|
||||||
|
df = y_predict.to_frame().join(y_true)
|
||||||
|
if sample_weight is not None:
|
||||||
|
sample_weight.name = "weight"
|
||||||
|
df = df.join(sample_weight)
|
||||||
|
kwargs = {"weightCol": "weight"}
|
||||||
|
|
||||||
|
df = df.to_spark()
|
||||||
|
|
||||||
|
metric_name = metric_name.lower()
|
||||||
|
min_mode_metrics = ["log_loss", "rmse", "mse", "mae"]
|
||||||
|
|
||||||
|
if metric_name == "rmse":
|
||||||
|
evaluator = RegressionEvaluator(
|
||||||
|
metricName="rmse",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "mse":
|
||||||
|
evaluator = RegressionEvaluator(
|
||||||
|
metricName="mse",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "mae":
|
||||||
|
evaluator = RegressionEvaluator(
|
||||||
|
metricName="mae",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "r2":
|
||||||
|
evaluator = RegressionEvaluator(
|
||||||
|
metricName="r2",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "var":
|
||||||
|
evaluator = RegressionEvaluator(
|
||||||
|
metricName="var",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "roc_auc":
|
||||||
|
evaluator = BinaryClassificationEvaluator(
|
||||||
|
metricName="areaUnderROC",
|
||||||
|
labelCol=label_col,
|
||||||
|
rawPredictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "pr_auc":
|
||||||
|
evaluator = BinaryClassificationEvaluator(
|
||||||
|
metricName="areaUnderPR",
|
||||||
|
labelCol=label_col,
|
||||||
|
rawPredictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "accuracy":
|
||||||
|
evaluator = MulticlassClassificationEvaluator(
|
||||||
|
metricName="accuracy",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "log_loss":
|
||||||
|
# For log_loss, prediction_col should be probability, and we need to convert it to label
|
||||||
|
df = _compute_label_from_probability(
|
||||||
|
df, prediction_col, prediction_col + "_label"
|
||||||
|
)
|
||||||
|
evaluator = MulticlassClassificationEvaluator(
|
||||||
|
metricName="logLoss",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col + "_label",
|
||||||
|
probabilityCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "f1":
|
||||||
|
evaluator = MulticlassClassificationEvaluator(
|
||||||
|
metricName="f1",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "micro_f1":
|
||||||
|
evaluator = MultilabelClassificationEvaluator(
|
||||||
|
metricName="microF1Measure",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "macro_f1":
|
||||||
|
evaluator = MultilabelClassificationEvaluator(
|
||||||
|
metricName="f1MeasureByLabel",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif metric_name == "ap":
|
||||||
|
evaluator = RankingEvaluator(
|
||||||
|
metricName="meanAveragePrecision",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
)
|
||||||
|
elif "ndcg" in metric_name:
|
||||||
|
# TODO: check if spark.ml ranker has the same format with
|
||||||
|
# synapseML ranker, may need to adjust the format of df
|
||||||
|
if "@" in metric_name:
|
||||||
|
k = int(metric_name.split("@", 1)[-1])
|
||||||
|
if groups is None:
|
||||||
|
evaluator = RankingEvaluator(
|
||||||
|
metricName="ndcgAtK",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
k=k,
|
||||||
|
)
|
||||||
|
df = _process_df(df, label_col, prediction_col)
|
||||||
|
score = 1 - evaluator.evaluate(df)
|
||||||
|
else:
|
||||||
|
counts = ps_group_counts(groups)
|
||||||
|
score = 0
|
||||||
|
psum = 0
|
||||||
|
for c in counts:
|
||||||
|
y_true_ = y_true[psum : psum + c]
|
||||||
|
y_predict_ = y_predict[psum : psum + c]
|
||||||
|
df = y_true_.to_frame().join(y_predict_).to_spark()
|
||||||
|
df = _process_df(df, label_col, prediction_col)
|
||||||
|
evaluator = RankingEvaluator(
|
||||||
|
metricName="ndcgAtK",
|
||||||
|
labelCol=label_col,
|
||||||
|
predictionCol=prediction_col,
|
||||||
|
k=k,
|
||||||
|
)
|
||||||
|
score -= evaluator.evaluate(df)
|
||||||
|
psum += c
|
||||||
|
score /= len(counts)
|
||||||
|
score += 1
|
||||||
|
else:
|
||||||
|
evaluator = RankingEvaluator(
|
||||||
|
metricName="ndcgAtK", labelCol=label_col, predictionCol=prediction_col
|
||||||
|
)
|
||||||
|
df = _process_df(df, label_col, prediction_col)
|
||||||
|
score = 1 - evaluator.evaluate(df)
|
||||||
|
return score
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown metric name: {metric_name} for spark models.")
|
||||||
|
|
||||||
|
return (
|
||||||
|
evaluator.evaluate(df)
|
||||||
|
if metric_name in min_mode_metrics
|
||||||
|
else 1 - evaluator.evaluate(df)
|
||||||
|
)
|
|
@ -0,0 +1,264 @@
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Union, List, Optional, Tuple
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger_formatter = logging.Formatter(
|
||||||
|
"[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s", "%m-%d %H:%M:%S"
|
||||||
|
)
|
||||||
|
logger.propagate = False
|
||||||
|
try:
|
||||||
|
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
|
||||||
|
from pyspark.sql import SparkSession
|
||||||
|
from pyspark.sql import DataFrame
|
||||||
|
import pyspark.pandas as ps
|
||||||
|
from pyspark.util import VersionUtils
|
||||||
|
import pyspark.sql.functions as F
|
||||||
|
import pyspark.sql.types as T
|
||||||
|
import pyspark
|
||||||
|
|
||||||
|
_spark_major_minor_version = VersionUtils.majorMinorVersion(pyspark.__version__)
|
||||||
|
except ImportError:
|
||||||
|
msg = """use_spark=True requires installation of PySpark. Please run pip install flaml[spark]
|
||||||
|
and check [here](https://spark.apache.org/docs/latest/api/python/getting_started/install.html)
|
||||||
|
for more details about installing Spark."""
|
||||||
|
raise ImportError(msg)
|
||||||
|
|
||||||
|
|
||||||
|
def to_pandas_on_spark(
|
||||||
|
df: Union[pd.DataFrame, DataFrame, pd.Series, ps.DataFrame, ps.Series],
|
||||||
|
index_col: Optional[str] = None,
|
||||||
|
default_index_type: Optional[str] = "distributed-sequence",
|
||||||
|
) -> Union[ps.DataFrame, ps.Series]:
|
||||||
|
"""Convert pandas or pyspark dataframe/series to pandas_on_Spark dataframe/series.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: pandas.DataFrame/series or pyspark dataframe | The input dataframe/series.
|
||||||
|
index_col: str, optional | The column name to use as index, default None.
|
||||||
|
default_index_type: str, optional | The default index type, default "distributed-sequence".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pyspark.pandas.DataFrame/Series: The converted pandas-on-Spark dataframe/series.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import pandas as pd
|
||||||
|
from flaml.automl.spark.utils import to_pandas_on_spark
|
||||||
|
|
||||||
|
pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
|
||||||
|
psdf = to_pandas_on_spark(pdf)
|
||||||
|
print(psdf)
|
||||||
|
|
||||||
|
from pyspark.sql import SparkSession
|
||||||
|
|
||||||
|
spark = SparkSession.builder.getOrCreate()
|
||||||
|
sdf = spark.createDataFrame(pdf)
|
||||||
|
psdf = to_pandas_on_spark(sdf)
|
||||||
|
print(psdf)
|
||||||
|
|
||||||
|
pds = pd.Series([1, 2, 3])
|
||||||
|
pss = to_pandas_on_spark(pds)
|
||||||
|
print(pss)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
ps.set_option("compute.default_index_type", default_index_type)
|
||||||
|
if isinstance(df, (pd.DataFrame, pd.Series)):
|
||||||
|
return ps.from_pandas(df)
|
||||||
|
elif isinstance(df, DataFrame):
|
||||||
|
if _spark_major_minor_version[0] == 3 and _spark_major_minor_version[1] < 3:
|
||||||
|
return df.to_pandas_on_spark(index_col=index_col)
|
||||||
|
else:
|
||||||
|
return df.pandas_api(index_col=index_col)
|
||||||
|
elif isinstance(df, (ps.DataFrame, ps.Series)):
|
||||||
|
return df
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"{type(df)} is not one of pandas.DataFrame, pandas.Series and pyspark.sql.DataFrame"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def train_test_split_pyspark(
|
||||||
|
df: Union[DataFrame, ps.DataFrame],
|
||||||
|
stratify_column: Optional[str] = None,
|
||||||
|
test_fraction: Optional[float] = 0.2,
|
||||||
|
seed: Optional[int] = 1234,
|
||||||
|
to_pandas_spark: Optional[bool] = True,
|
||||||
|
index_col: Optional[str] = "tmp_index_col",
|
||||||
|
) -> Tuple[Union[DataFrame, ps.DataFrame], Union[DataFrame, ps.DataFrame]]:
|
||||||
|
"""Split a pyspark dataframe into train and test dataframes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: pyspark.sql.DataFrame | The input dataframe.
|
||||||
|
stratify_column: str | The column name to stratify the split. Default None.
|
||||||
|
test_fraction: float | The fraction of the test data. Default 0.2.
|
||||||
|
seed: int | The random seed. Default 1234.
|
||||||
|
to_pandas_spark: bool | Whether to convert the output to pandas_on_spark. Default True.
|
||||||
|
index_col: str | The column name to use as index. Default None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pyspark.sql.DataFrame/pandas_on_spark DataFrame | The train dataframe.
|
||||||
|
pyspark.sql.DataFrame/pandas_on_spark DataFrame | The test dataframe.
|
||||||
|
"""
|
||||||
|
if isinstance(df, ps.DataFrame):
|
||||||
|
df = df.to_spark(index_col=index_col)
|
||||||
|
|
||||||
|
if stratify_column:
|
||||||
|
# Test data
|
||||||
|
test_fraction_dict = (
|
||||||
|
df.select(stratify_column)
|
||||||
|
.distinct()
|
||||||
|
.withColumn("fraction", F.lit(test_fraction))
|
||||||
|
.rdd.collectAsMap()
|
||||||
|
)
|
||||||
|
df_test = df.stat.sampleBy(stratify_column, test_fraction_dict, seed)
|
||||||
|
# Train data
|
||||||
|
df_train = df.subtract(df_test)
|
||||||
|
else:
|
||||||
|
df_train, df_test = df.randomSplit([1 - test_fraction, test_fraction], seed)
|
||||||
|
|
||||||
|
if to_pandas_spark:
|
||||||
|
df_train = to_pandas_on_spark(df_train, index_col=index_col)
|
||||||
|
df_test = to_pandas_on_spark(df_test, index_col=index_col)
|
||||||
|
df_train.index.name = None
|
||||||
|
df_test.index.name = None
|
||||||
|
elif index_col == "tmp_index_col":
|
||||||
|
df_train = df_train.drop(index_col)
|
||||||
|
df_test = df_test.drop(index_col)
|
||||||
|
return [df_train, df_test]
|
||||||
|
|
||||||
|
|
||||||
|
def unique_pandas_on_spark(
|
||||||
|
psds: Union[ps.Series, ps.DataFrame]
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Get the unique values and counts of a pandas_on_spark series."""
|
||||||
|
if isinstance(psds, ps.DataFrame):
|
||||||
|
psds = psds.iloc[:, 0]
|
||||||
|
_tmp = psds.value_counts().to_pandas()
|
||||||
|
label_set = _tmp.index.values
|
||||||
|
counts = _tmp.values
|
||||||
|
return label_set, counts
|
||||||
|
|
||||||
|
|
||||||
|
def len_labels(
|
||||||
|
y: Union[ps.Series, np.ndarray], return_labels=False
|
||||||
|
) -> Union[int, Optional[np.ndarray]]:
|
||||||
|
"""Get the number of unique labels in y."""
|
||||||
|
if not isinstance(y, (ps.DataFrame, ps.Series)):
|
||||||
|
labels = np.unique(y)
|
||||||
|
else:
|
||||||
|
labels = y.unique() if isinstance(y, ps.Series) else y.iloc[:, 0].unique()
|
||||||
|
if return_labels:
|
||||||
|
return len(labels), labels
|
||||||
|
return len(labels)
|
||||||
|
|
||||||
|
|
||||||
|
def unique_value_first_index(
|
||||||
|
y: Union[pd.Series, ps.Series, np.ndarray]
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""Get the unique values and indices of a pandas series,
|
||||||
|
pandas_on_spark series or numpy array."""
|
||||||
|
if isinstance(y, ps.Series):
|
||||||
|
y_unique = y.drop_duplicates().sort_index()
|
||||||
|
label_set = y_unique.values
|
||||||
|
first_index = y_unique.index.values
|
||||||
|
else:
|
||||||
|
label_set, first_index = np.unique(y, return_index=True)
|
||||||
|
return label_set, first_index
|
||||||
|
|
||||||
|
|
||||||
|
def iloc_pandas_on_spark(
|
||||||
|
psdf: Union[ps.DataFrame, ps.Series, pd.DataFrame, pd.Series],
|
||||||
|
index: Union[int, slice, list],
|
||||||
|
index_col: Optional[str] = "tmp_index_col",
|
||||||
|
) -> Union[ps.DataFrame, ps.Series]:
|
||||||
|
"""Get the rows of a pandas_on_spark dataframe/series by index."""
|
||||||
|
if isinstance(psdf, (pd.DataFrame, pd.Series)):
|
||||||
|
return psdf.iloc[index]
|
||||||
|
if isinstance(index, (int, slice)):
|
||||||
|
if isinstance(psdf, ps.Series):
|
||||||
|
return psdf.iloc[index]
|
||||||
|
else:
|
||||||
|
return psdf.iloc[index, :]
|
||||||
|
elif isinstance(index, list):
|
||||||
|
if isinstance(psdf, ps.Series):
|
||||||
|
sdf = psdf.to_frame().to_spark(index_col=index_col)
|
||||||
|
else:
|
||||||
|
if index_col not in psdf.columns:
|
||||||
|
sdf = psdf.to_spark(index_col=index_col)
|
||||||
|
else:
|
||||||
|
sdf = psdf.to_spark()
|
||||||
|
sdfiloc = sdf.filter(F.col(index_col).isin(index))
|
||||||
|
psdfiloc = to_pandas_on_spark(sdfiloc)
|
||||||
|
if isinstance(psdf, ps.Series):
|
||||||
|
psdfiloc = psdfiloc[psdfiloc.columns.drop(index_col)[0]]
|
||||||
|
elif index_col not in psdf.columns:
|
||||||
|
psdfiloc = psdfiloc.drop(columns=[index_col])
|
||||||
|
return psdfiloc
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
f"{type(index)} is not one of int, slice and list for pandas_on_spark iloc"
|
||||||
|
)
|
||||||
|
|
def spark_kFold(
    dataset: Union[DataFrame, ps.DataFrame],
    nFolds: int = 3,
    foldCol: str = "",
    seed: int = 42,
    index_col: Optional[str] = "tmp_index_col",
) -> List[Tuple[ps.DataFrame, ps.DataFrame]]:
    """Generate k-fold splits for a Spark DataFrame.
    Adopted from https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/tuning.html#CrossValidator

    Args:
        dataset: DataFrame / ps.DataFrame. | The DataFrame to split.
        nFolds: int | The number of folds. Default is 3.
        foldCol: str | The column name to use for fold numbers. If not specified,
            the DataFrame will be randomly split. Default is "".
            The same group will not appear in two different folds (the number of
            distinct groups has to be at least equal to the number of folds).
            The folds are approximately balanced in the sense that the number of
            distinct groups is approximately the same in each fold.
        seed: int | The random seed. Default is 42.
        index_col: str | The name of the index column. Default is "tmp_index_col".

    Returns:
        A list of (train, validation) DataFrames.
    """
    if isinstance(dataset, ps.DataFrame):
        dataset = dataset.to_spark(index_col=index_col)

    datasets = []
    if not foldCol:
        # Do random k-fold split.
        h = 1.0 / nFolds
        randCol = f"rand_col_{seed}"
        df = dataset.select("*", F.rand(seed).alias(randCol))
        for i in range(nFolds):
            validateLB = i * h
            validateUB = (i + 1) * h
            condition = (df[randCol] >= validateLB) & (df[randCol] < validateUB)
            validation = to_pandas_on_spark(df.filter(condition), index_col=index_col)
            train = to_pandas_on_spark(df.filter(~condition), index_col=index_col)
            datasets.append(
                (train.drop(columns=[randCol]), validation.drop(columns=[randCol]))
            )
    else:
        # Use user-specified fold column
        def get_fold_num(foldNum: int) -> int:
            return int(foldNum % nFolds)

        get_fold_num_udf = F.UserDefinedFunction(get_fold_num, T.IntegerType())
        for i in range(nFolds):
            training = dataset.filter(get_fold_num_udf(dataset[foldCol]) != F.lit(i))
            validation = dataset.filter(get_fold_num_udf(dataset[foldCol]) == F.lit(i))
            if training.rdd.getNumPartitions() == 0 or len(training.take(1)) == 0:
                raise ValueError("The training data at fold %s is empty." % i)
            if validation.rdd.getNumPartitions() == 0 or len(validation.take(1)) == 0:
                raise ValueError("The validation data at fold %s is empty." % i)
            training = to_pandas_on_spark(training, index_col=index_col)
            validation = to_pandas_on_spark(validation, index_col=index_col)
            datasets.append((training, validation))

    return datasets
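A brief usage sketch of the random-split path of `spark_kFold` (the toy data and fold count are illustrative; per the signature the helper also accepts a plain `pyspark.sql.DataFrame`):

import pyspark.pandas as ps
from flaml.automl.spark.utils import spark_kFold

psdf = ps.DataFrame({"x": list(range(100)), "label": [i % 2 for i in range(100)]})
folds = spark_kFold(psdf, nFolds=3, seed=42)  # random split because foldCol is not given
for k, (train, val) in enumerate(folds):
    # each element is a (train, validation) pair of pandas-on-Spark frames
    print(k, train.shape[0], val.shape[0])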
@ -1,5 +1,6 @@
|
||||||
import inspect
|
import inspect
|
||||||
import time
|
import time
|
||||||
|
import os
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -10,6 +11,34 @@ from flaml.automl.logger import logger
|
||||||
from flaml.automl.ml import compute_estimator, train_estimator
|
from flaml.automl.ml import compute_estimator, train_estimator
|
||||||
from flaml.automl.task.task import TS_FORECAST
|
from flaml.automl.task.task import TS_FORECAST
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flaml.automl.spark.utils import (
|
||||||
|
train_test_split_pyspark,
|
||||||
|
unique_pandas_on_spark,
|
||||||
|
len_labels,
|
||||||
|
unique_value_first_index,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
train_test_split_pyspark = None
|
||||||
|
unique_pandas_on_spark = None
|
||||||
|
from flaml.automl.utils import (
|
||||||
|
len_labels,
|
||||||
|
unique_value_first_index,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
|
||||||
|
import pyspark.pandas as ps
|
||||||
|
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
|
||||||
|
from pyspark.pandas.config import set_option, reset_option
|
||||||
|
except ImportError:
|
||||||
|
ps = None
|
||||||
|
|
||||||
|
class psDataFrame:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class psSeries:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SearchState:
|
class SearchState:
|
||||||
@property
|
@property
|
||||||
|
@ -241,11 +270,11 @@ class AutoMLState:
|
||||||
def _prepare_sample_train_data(self, sample_size: int):
|
def _prepare_sample_train_data(self, sample_size: int):
|
||||||
sampled_weight = groups = None
|
sampled_weight = groups = None
|
||||||
if sample_size <= self.data_size[0]:
|
if sample_size <= self.data_size[0]:
|
||||||
if isinstance(self.X_train, pd.DataFrame):
|
if isinstance(self.X_train, (pd.DataFrame, psDataFrame)):
|
||||||
sampled_X_train = self.X_train.iloc[:sample_size]
|
sampled_X_train = self.X_train.iloc[:sample_size]
|
||||||
else:
|
else:
|
||||||
sampled_X_train = self.X_train[:sample_size]
|
sampled_X_train = self.X_train[:sample_size]
|
||||||
if isinstance(self.y_train, pd.Series):
|
if isinstance(self.y_train, (pd.Series, psSeries)):
|
||||||
sampled_y_train = self.y_train.iloc[:sample_size]
|
sampled_y_train = self.y_train.iloc[:sample_size]
|
||||||
else:
|
else:
|
||||||
sampled_y_train = self.y_train[:sample_size]
|
sampled_y_train = self.y_train[:sample_size]
|
||||||
|
@ -255,13 +284,13 @@ class AutoMLState:
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
sampled_weight = (
|
sampled_weight = (
|
||||||
weight.iloc[:sample_size]
|
weight.iloc[:sample_size]
|
||||||
if isinstance(weight, pd.Series)
|
if isinstance(weight, (pd.Series, psSeries))
|
||||||
else weight[:sample_size]
|
else weight[:sample_size]
|
||||||
)
|
)
|
||||||
if self.groups is not None:
|
if self.groups is not None:
|
||||||
groups = (
|
groups = (
|
||||||
self.groups.iloc[:sample_size]
|
self.groups.iloc[:sample_size]
|
||||||
if isinstance(self.groups, pd.Series)
|
if isinstance(self.groups, (pd.Series, psSeries))
|
||||||
else self.groups[:sample_size]
|
else self.groups[:sample_size]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
@ -30,6 +31,7 @@ from flaml.automl.model import (
|
||||||
KNeighborsEstimator,
|
KNeighborsEstimator,
|
||||||
TransformersEstimator,
|
TransformersEstimator,
|
||||||
TransformersEstimatorModelSelection,
|
TransformersEstimatorModelSelection,
|
||||||
|
SparkLGBMEstimator,
|
||||||
)
|
)
|
||||||
from flaml.automl.task.task import (
|
from flaml.automl.task.task import (
|
||||||
Task,
|
Task,
|
||||||
|
@ -39,6 +41,40 @@ from flaml.automl.task.task import (
|
||||||
)
|
)
|
||||||
from flaml.config import RANDOM_SEED
|
from flaml.config import RANDOM_SEED
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
|
||||||
|
from pyspark.sql.functions import col
|
||||||
|
import pyspark.pandas as ps
|
||||||
|
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
|
||||||
|
from pyspark.pandas.config import set_option, reset_option
|
||||||
|
from flaml.automl.spark.utils import (
|
||||||
|
to_pandas_on_spark,
|
||||||
|
iloc_pandas_on_spark,
|
||||||
|
spark_kFold,
|
||||||
|
train_test_split_pyspark,
|
||||||
|
unique_pandas_on_spark,
|
||||||
|
unique_value_first_index,
|
||||||
|
len_labels,
|
||||||
|
)
|
||||||
|
from flaml.automl.spark.metrics import spark_metric_loss_score
|
||||||
|
except ImportError:
|
||||||
|
train_test_split_pyspark = None
|
||||||
|
unique_pandas_on_spark = None
|
||||||
|
iloc_pandas_on_spark = None
|
||||||
|
from flaml.automl.utils import (
|
||||||
|
len_labels,
|
||||||
|
unique_value_first_index,
|
||||||
|
)
|
||||||
|
|
||||||
|
ps = None
|
||||||
|
|
||||||
|
class psDataFrame:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class psSeries:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,6 +91,7 @@ class GenericTask(Task):
|
||||||
"kneighbor": KNeighborsEstimator,
|
"kneighbor": KNeighborsEstimator,
|
||||||
"transformer": TransformersEstimator,
|
"transformer": TransformersEstimator,
|
||||||
"transformer_ms": TransformersEstimatorModelSelection,
|
"transformer_ms": TransformersEstimatorModelSelection,
|
||||||
|
"lgbm_spark": SparkLGBMEstimator,
|
||||||
}
|
}
|
||||||
|
|
||||||
def validate_data(
|
def validate_data(
|
||||||
|
@ -71,17 +108,15 @@ class GenericTask(Task):
|
||||||
groups=None,
|
groups=None,
|
||||||
):
|
):
|
||||||
if X_train_all is not None and y_train_all is not None:
|
if X_train_all is not None and y_train_all is not None:
|
||||||
assert (
|
assert isinstance(
|
||||||
isinstance(X_train_all, np.ndarray)
|
X_train_all, (np.ndarray, pd.DataFrame, psDataFrame)
|
||||||
or issparse(X_train_all)
|
) or issparse(X_train_all), (
|
||||||
or isinstance(X_train_all, pd.DataFrame)
|
|
||||||
), (
|
|
||||||
"X_train_all must be a numpy array, a pandas dataframe, "
|
"X_train_all must be a numpy array, a pandas dataframe, "
|
||||||
"or Scipy sparse matrix."
|
"a Scipy sparse matrix or a pyspark.pandas dataframe."
|
||||||
)
|
)
|
||||||
assert isinstance(y_train_all, np.ndarray) or isinstance(
|
assert isinstance(
|
||||||
y_train_all, pd.Series
|
y_train_all, (np.ndarray, pd.Series, psSeries)
|
||||||
), "y_train_all must be a numpy array or a pandas series."
|
), "y_train_all must be a numpy array, a pandas series or a pyspark.pandas series."
|
||||||
assert (
|
assert (
|
||||||
X_train_all.size != 0 and y_train_all.size != 0
|
X_train_all.size != 0 and y_train_all.size != 0
|
||||||
), "Input data must not be empty."
|
), "Input data must not be empty."
|
||||||
|
@ -92,22 +127,42 @@ class GenericTask(Task):
|
||||||
assert (
|
assert (
|
||||||
X_train_all.shape[0] == y_train_all.shape[0]
|
X_train_all.shape[0] == y_train_all.shape[0]
|
||||||
), "# rows in X_train must match length of y_train."
|
), "# rows in X_train must match length of y_train."
|
||||||
automl._df = isinstance(X_train_all, pd.DataFrame)
|
if isinstance(X_train_all, psDataFrame):
|
||||||
|
X_train_all = (
|
||||||
|
X_train_all.spark.cache()
|
||||||
|
) # cache data to improve compute speed
|
||||||
|
y_train_all = y_train_all.to_frame().spark.cache()[y_train_all.name]
|
||||||
|
logger.debug(
|
||||||
|
f"X_train_all and y_train_all cached, shape of X_train_all: {X_train_all.shape}"
|
||||||
|
)
|
||||||
|
automl._df = isinstance(X_train_all, (pd.DataFrame, psDataFrame))
|
||||||
automl._nrow, automl._ndim = X_train_all.shape
|
automl._nrow, automl._ndim = X_train_all.shape
|
||||||
if self.is_ts_forecast():
|
if self.is_ts_forecast():
|
||||||
X_train_all = pd.DataFrame(X_train_all)
|
X_train_all = (
|
||||||
|
pd.DataFrame(X_train_all)
|
||||||
|
if isinstance(X_train_all, np.ndarray)
|
||||||
|
else X_train_all
|
||||||
|
)
|
||||||
X_train_all, y_train_all = self._validate_ts_data(
|
X_train_all, y_train_all = self._validate_ts_data(
|
||||||
X_train_all, y_train_all
|
X_train_all, y_train_all
|
||||||
)
|
)
|
||||||
X, y = X_train_all, y_train_all
|
X, y = X_train_all, y_train_all
|
||||||
elif dataframe is not None and label is not None:
|
elif dataframe is not None and label is not None:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
dataframe, pd.DataFrame
|
dataframe, (pd.DataFrame, psDataFrame)
|
||||||
), "dataframe must be a pandas DataFrame"
|
), "dataframe must be a pandas DataFrame or a pyspark.pandas DataFrame."
|
||||||
assert label in dataframe.columns, "label must a column name in dataframe"
|
assert (
|
||||||
|
label in dataframe.columns
|
||||||
|
), f"The provided label column name `{label}` doesn't exist in the provided dataframe."
|
||||||
|
if isinstance(dataframe, psDataFrame):
|
||||||
|
dataframe = (
|
||||||
|
dataframe.spark.cache()
|
||||||
|
) # cache data to improve compute speed
|
||||||
|
logger.debug(f"dataframe cached, shape of dataframe: {dataframe.shape}")
|
||||||
automl._df = True
|
automl._df = True
|
||||||
if self.is_ts_forecast():
|
if self.is_ts_forecast():
|
||||||
dataframe = self._validate_ts_data(dataframe)
|
dataframe = self._validate_ts_data(dataframe)
|
||||||
|
# TODO: to support pyspark.sql.DataFrame and pure dataframe mode
|
||||||
X = dataframe.drop(columns=label)
|
X = dataframe.drop(columns=label)
|
||||||
automl._nrow, automl._ndim = X.shape
|
automl._nrow, automl._ndim = X.shape
|
||||||
y = dataframe[label]
|
y = dataframe[label]
|
||||||
|
@ -125,7 +180,7 @@ class GenericTask(Task):
|
||||||
"object",
|
"object",
|
||||||
"string",
|
"string",
|
||||||
), "If the task is an NLP task, X can only contain text columns"
|
), "If the task is an NLP task, X can only contain text columns"
|
||||||
for each_cell in X[column]:
|
for _, each_cell in X[column].items():
|
||||||
if each_cell is not None:
|
if each_cell is not None:
|
||||||
is_str = isinstance(each_cell, str)
|
is_str = isinstance(each_cell, str)
|
||||||
is_list_of_int = isinstance(each_cell, list) and all(
|
is_list_of_int = isinstance(each_cell, list) and all(
|
||||||
|
@ -149,8 +204,10 @@ class GenericTask(Task):
|
||||||
"Currently FLAML only supports two modes for NLP: either all columns of X are string (non-tokenized), "
|
"Currently FLAML only supports two modes for NLP: either all columns of X are string (non-tokenized), "
|
||||||
"or all columns of X are integer ids (tokenized)"
|
"or all columns of X are integer ids (tokenized)"
|
||||||
)
|
)
|
||||||
|
if isinstance(X, psDataFrame):
|
||||||
if issparse(X_train_all) or automl._skip_transform:
|
# TODO: support pyspark.pandas dataframe in DataTransformer
|
||||||
|
automl._skip_transform = True
|
||||||
|
if automl._skip_transform or issparse(X_train_all):
|
||||||
automl._transformer = automl._label_transformer = False
|
automl._transformer = automl._label_transformer = False
|
||||||
automl._X_train_all, automl._y_train_all = X, y
|
automl._X_train_all, automl._y_train_all = X, y
|
||||||
else:
|
else:
|
||||||
|
@ -184,17 +241,16 @@ class GenericTask(Task):
|
||||||
"sample_weight"
|
"sample_weight"
|
||||||
) # NOTE: _validate_data is before kwargs is updated to fit_kwargs_by_estimator
|
) # NOTE: _validate_data is before kwargs is updated to fit_kwargs_by_estimator
|
||||||
if X_val is not None and y_val is not None:
|
if X_val is not None and y_val is not None:
|
||||||
assert (
|
assert isinstance(
|
||||||
isinstance(X_val, np.ndarray)
|
X_val, (np.ndarray, pd.DataFrame, psDataFrame)
|
||||||
or issparse(X_val)
|
) or issparse(X_train_all), (
|
||||||
or isinstance(X_val, pd.DataFrame)
|
|
||||||
), (
|
|
||||||
"X_val must be None, a numpy array, a pandas dataframe, "
|
"X_val must be None, a numpy array, a pandas dataframe, "
|
||||||
"or Scipy sparse matrix."
|
"a Scipy sparse matrix or a pyspark.pandas dataframe."
|
||||||
|
)
|
||||||
|
assert isinstance(y_val, (np.ndarray, pd.Series, psSeries)), (
|
||||||
|
"y_val must be None, a numpy array, a pandas series "
|
||||||
|
"or a pyspark.pandas series."
|
||||||
)
|
)
|
||||||
assert isinstance(y_val, np.ndarray) or isinstance(
|
|
||||||
y_val, pd.Series
|
|
||||||
), "y_val must be None, a numpy array or a pandas series."
|
|
||||||
assert X_val.size != 0 and y_val.size != 0, (
|
assert X_val.size != 0 and y_val.size != 0, (
|
||||||
"Validation data are expected to be nonempty. "
|
"Validation data are expected to be nonempty. "
|
||||||
"Use None for X_val and y_val if no validation data."
|
"Use None for X_val and y_val if no validation data."
|
||||||
|
@ -241,25 +297,39 @@ class GenericTask(Task):
|
||||||
dataframe[dataframe.columns[0]].dtype.name == "datetime64[ns]"
|
dataframe[dataframe.columns[0]].dtype.name == "datetime64[ns]"
|
||||||
), f"For '{TS_FORECAST}' task, the first column must contain timestamp values."
|
), f"For '{TS_FORECAST}' task, the first column must contain timestamp values."
|
||||||
if y_train_all is not None:
|
if y_train_all is not None:
|
||||||
y_df = (
|
if isinstance(y_train_all, pd.Series):
|
||||||
pd.DataFrame(y_train_all)
|
y_df = pd.DataFrame(y_train_all)
|
||||||
if isinstance(y_train_all, pd.Series)
|
elif isinstance(y_train_all, np.ndarray):
|
||||||
else pd.DataFrame(y_train_all, columns=["labels"])
|
y_df = pd.DataFrame(y_train_all, columns=["labels"])
|
||||||
)
|
elif isinstance(y_train_all, (psDataFrame, psSeries)):
|
||||||
|
# TODO: optimize this
|
||||||
|
set_option("compute.ops_on_diff_frames", True)
|
||||||
|
y_df = y_train_all
|
||||||
dataframe = dataframe.join(y_df)
|
dataframe = dataframe.join(y_df)
|
||||||
duplicates = dataframe.duplicated()
|
duplicates = dataframe.duplicated()
|
||||||
if any(duplicates):
|
if isinstance(dataframe, psDataFrame):
|
||||||
logger.warning(
|
if duplicates.any():
|
||||||
"Duplicate timestamp values found in timestamp column. "
|
logger.warning("Duplicate timestamp values found in timestamp column.")
|
||||||
f"\n{dataframe.loc[duplicates, dataframe][dataframe.columns[0]]}"
|
dataframe = dataframe.drop_duplicates()
|
||||||
)
|
logger.warning("Removed duplicate rows based on all columns")
|
||||||
dataframe = dataframe.drop_duplicates()
|
assert (
|
||||||
logger.warning("Removed duplicate rows based on all columns")
|
dataframe[[dataframe.columns[0]]].duplicated().any() is False
|
||||||
assert (
|
), "Duplicate timestamp values with different values for other columns."
|
||||||
dataframe[[dataframe.columns[0]]].duplicated() is None
|
ts_series = ps.to_datetime(dataframe[dataframe.columns[0]])
|
||||||
), "Duplicate timestamp values with different values for other columns."
|
inferred_freq = None # TODO: `pd.infer_freq()` is not implemented yet.
|
||||||
ts_series = pd.to_datetime(dataframe[dataframe.columns[0]])
|
else:
|
||||||
inferred_freq = pd.infer_freq(ts_series)
|
if any(duplicates):
|
||||||
|
logger.warning(
|
||||||
|
"Duplicate timestamp values found in timestamp column. "
|
||||||
|
f"\n{dataframe.loc[duplicates, dataframe][dataframe.columns[0]]}"
|
||||||
|
)
|
||||||
|
dataframe = dataframe.drop_duplicates()
|
||||||
|
logger.warning("Removed duplicate rows based on all columns")
|
||||||
|
assert (
|
||||||
|
dataframe[[dataframe.columns[0]]].duplicated() is None
|
||||||
|
), "Duplicate timestamp values with different values for other columns."
|
||||||
|
ts_series = pd.to_datetime(dataframe[dataframe.columns[0]])
|
||||||
|
inferred_freq = pd.infer_freq(ts_series)
|
||||||
if inferred_freq is None:
|
if inferred_freq is None:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Missing timestamps detected. To avoid error with estimators, set estimator list to ['prophet']. "
|
"Missing timestamps detected. To avoid error with estimators, set estimator list to ['prophet']. "
|
||||||
|
@ -268,6 +338,121 @@ class GenericTask(Task):
|
||||||
return dataframe.iloc[:, :-1], dataframe.iloc[:, -1]
|
return dataframe.iloc[:, :-1], dataframe.iloc[:, -1]
|
||||||
return dataframe
|
return dataframe
|
||||||
|
|
||||||
|
    @staticmethod
    def _split_pyspark(state, X_train_all, y_train_all, split_ratio, stratify=None):
        # TODO: optimize this
        set_option("compute.ops_on_diff_frames", True)
        if not isinstance(y_train_all, (psDataFrame, psSeries)):
            raise ValueError("y_train_all must be a pyspark.pandas dataframe or series")
        df_all_in_one = X_train_all.join(y_train_all)
        stratify_column = (
            y_train_all.name if isinstance(y_train_all, psSeries) else y_train_all.columns[0]
        )
        ret_sample_weight = False
        if (
            "sample_weight" in state.fit_kwargs
        ):  # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
            # fit_kwargs["sample_weight"] is a numpy array
            ps_sample_weight = ps.DataFrame(
                state.fit_kwargs["sample_weight"], columns=["sample_weight"]
            )
            df_all_in_one = df_all_in_one.join(ps_sample_weight)
            ret_sample_weight = True
        df_all_train, df_all_val = train_test_split_pyspark(
            df_all_in_one,
            None if stratify is None else stratify_column,
            test_fraction=split_ratio,
            seed=RANDOM_SEED,
        )
        columns_to_drop = [
            c for c in df_all_train.columns if c in [stratify_column, "sample_weight"]
        ]
        X_train = df_all_train.drop(columns_to_drop)
        X_val = df_all_val.drop(columns_to_drop)
        y_train = df_all_train[stratify_column]
        y_val = df_all_val[stratify_column]

        if ret_sample_weight:
            return (
                X_train,
                X_val,
                y_train,
                y_val,
                df_all_train["sample_weight"],
                df_all_val["sample_weight"],
            )
        return X_train, X_val, y_train, y_val

    @staticmethod
    def _train_test_split(state, X, y, first=None, rest=None, split_ratio=0.2, stratify=None):
        condition_type = isinstance(X, (psDataFrame, psSeries))
        # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
        condition_param = "sample_weight" in state.fit_kwargs
        if not condition_type and condition_param:
            sample_weight = (
                state.fit_kwargs["sample_weight"]
                if rest is None
                else state.fit_kwargs["sample_weight"][rest]
            )
            X_train, X_val, y_train, y_val, weight_train, weight_val = train_test_split(
                X,
                y,
                sample_weight,
                test_size=split_ratio,
                stratify=stratify,
                random_state=RANDOM_SEED,
            )

            if first is not None:
                weight1 = state.fit_kwargs["sample_weight"][first]
                state.weight_val = concat(weight1, weight_val)
                state.fit_kwargs["sample_weight"] = concat(weight1, weight_train)
            else:
                state.weight_val = weight_val
                state.fit_kwargs["sample_weight"] = weight_train
        elif not condition_type and not condition_param:
            X_train, X_val, y_train, y_val = train_test_split(
                X, y, test_size=split_ratio, stratify=stratify, random_state=RANDOM_SEED
            )
        elif condition_type and condition_param:
            (
                X_train,
                X_val,
                y_train,
                y_val,
                weight_train,
                weight_val,
            ) = GenericTask._split_pyspark(state, X, y, split_ratio, stratify)

            if first is not None:
                weight1 = state.fit_kwargs["sample_weight"][first]
                state.weight_val = concat(weight1, weight_val)
                state.fit_kwargs["sample_weight"] = concat(weight1, weight_train)
            else:
                state.weight_val = weight_val
                state.fit_kwargs["sample_weight"] = weight_train
        else:
            X_train, X_val, y_train, y_val = GenericTask._split_pyspark(
                state, X, y, split_ratio, stratify
            )
        return X_train, X_val, y_train, y_val
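For context, a minimal sketch of how the underlying `train_test_split_pyspark` helper is invoked by `_split_pyspark` above (the toy frame and column name are illustrative; the call shape mirrors the one in the method, and any other keyword arguments of the helper are not assumed here):

import pyspark.pandas as ps
from flaml.automl.spark.utils import train_test_split_pyspark
from flaml.config import RANDOM_SEED

df = ps.DataFrame({"x": list(range(20)), "label": [i % 2 for i in range(20)]})
train_df, val_df = train_test_split_pyspark(
    df,
    "label",              # stratify on the label column, as _split_pyspark does
    test_fraction=0.2,    # corresponds to split_ratio above
    seed=RANDOM_SEED,
)
print(train_df.shape[0], val_df.shape[0])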
def prepare_data(
|
def prepare_data(
|
||||||
self,
|
self,
|
||||||
state,
|
state,
|
||||||
|
@ -286,6 +471,8 @@ class GenericTask(Task):
|
||||||
X_val = X_val.tocsr()
|
X_val = X_val.tocsr()
|
||||||
if issparse(X_train_all):
|
if issparse(X_train_all):
|
||||||
X_train_all = X_train_all.tocsr()
|
X_train_all = X_train_all.tocsr()
|
||||||
|
is_spark_dataframe = isinstance(X_train_all, (psDataFrame, psSeries))
|
||||||
|
self.is_spark_dataframe = is_spark_dataframe
|
||||||
if (
|
if (
|
||||||
self.is_classification()
|
self.is_classification()
|
||||||
and auto_augment
|
and auto_augment
|
||||||
|
@ -295,12 +482,17 @@ class GenericTask(Task):
|
||||||
and not self.is_token_classification()
|
and not self.is_token_classification()
|
||||||
):
|
):
|
||||||
# logger.info(f"label {pd.unique(y_train_all)}")
|
# logger.info(f"label {pd.unique(y_train_all)}")
|
||||||
label_set, counts = np.unique(y_train_all, return_counts=True)
|
if is_spark_dataframe:
|
||||||
|
label_set, counts = unique_pandas_on_spark(y_train_all)
|
||||||
|
# TODO: optimize this
|
||||||
|
set_option("compute.ops_on_diff_frames", True)
|
||||||
|
else:
|
||||||
|
label_set, counts = np.unique(y_train_all, return_counts=True)
|
||||||
# augment rare classes
|
# augment rare classes
|
||||||
rare_threshld = 20
|
rare_threshld = 20
|
||||||
rare = counts < rare_threshld
|
rare = counts < rare_threshld
|
||||||
rare_label, rare_counts = label_set[rare], counts[rare]
|
rare_label, rare_counts = label_set[rare], counts[rare]
|
||||||
for i, label in enumerate(rare_label):
|
for i, label in enumerate(rare_label.tolist()):
|
||||||
count = rare_count = rare_counts[i]
|
count = rare_count = rare_counts[i]
|
||||||
rare_index = y_train_all == label
|
rare_index = y_train_all == label
|
||||||
n = len(y_train_all)
|
n = len(y_train_all)
|
||||||
|
@ -313,7 +505,7 @@ class GenericTask(Task):
|
||||||
X_train_all = concat(
|
X_train_all = concat(
|
||||||
X_train_all, X_train_all[:n][rare_index, :]
|
X_train_all, X_train_all[:n][rare_index, :]
|
||||||
)
|
)
|
||||||
if isinstance(y_train_all, pd.Series):
|
if isinstance(y_train_all, (pd.Series, psSeries)):
|
||||||
y_train_all = concat(
|
y_train_all = concat(
|
||||||
y_train_all, y_train_all.iloc[:n].loc[rare_index]
|
y_train_all, y_train_all.iloc[:n].loc[rare_index]
|
||||||
)
|
)
|
||||||
|
@ -324,7 +516,10 @@ class GenericTask(Task):
|
||||||
count += rare_count
|
count += rare_count
|
||||||
logger.info(f"class {label} augmented from {rare_count} to {count}")
|
logger.info(f"class {label} augmented from {rare_count} to {count}")
|
||||||
SHUFFLE_SPLIT_TYPES = ["uniform", "stratified"]
|
SHUFFLE_SPLIT_TYPES = ["uniform", "stratified"]
|
||||||
if split_type in SHUFFLE_SPLIT_TYPES:
|
if is_spark_dataframe:
|
||||||
|
# no need to shuffle pyspark dataframe
|
||||||
|
pass
|
||||||
|
elif split_type in SHUFFLE_SPLIT_TYPES:
|
||||||
if sample_weight_full is not None:
|
if sample_weight_full is not None:
|
||||||
X_train_all, y_train_all, state.sample_weight_all = shuffle(
|
X_train_all, y_train_all, state.sample_weight_all = shuffle(
|
||||||
X_train_all,
|
X_train_all,
|
||||||
|
@ -363,18 +558,26 @@ class GenericTask(Task):
|
||||||
ids = state.fit_kwargs["group_ids"].copy()
|
ids = state.fit_kwargs["group_ids"].copy()
|
||||||
ids.append(TS_TIMESTAMP_COL)
|
ids.append(TS_TIMESTAMP_COL)
|
||||||
ids.append("time_idx")
|
ids.append("time_idx")
|
||||||
y_train_all = pd.DataFrame(y_train_all)
|
y_train_all = (
|
||||||
|
pd.DataFrame(y_train_all)
|
||||||
|
if not is_spark_dataframe
|
||||||
|
else ps.DataFrame(y_train_all)
|
||||||
|
if isinstance(y_train_all, psSeries)
|
||||||
|
else y_train_all
|
||||||
|
)
|
||||||
y_train_all[ids] = X_train_all[ids]
|
y_train_all[ids] = X_train_all[ids]
|
||||||
X_train_all = X_train_all.sort_values(ids)
|
X_train_all = X_train_all.sort_values(ids)
|
||||||
y_train_all = y_train_all.sort_values(ids)
|
y_train_all = y_train_all.sort_values(ids)
|
||||||
training_cutoff = X_train_all["time_idx"].max() - period
|
training_cutoff = X_train_all["time_idx"].max() - period
|
||||||
X_train = X_train_all[lambda x: x.time_idx <= training_cutoff]
|
X_train = X_train_all[
|
||||||
|
X_train_all["time_idx"] <= training_cutoff
|
||||||
|
]
|
||||||
y_train = y_train_all[
|
y_train = y_train_all[
|
||||||
lambda x: x.time_idx <= training_cutoff
|
y_train_all["time_idx"] <= training_cutoff
|
||||||
].drop(columns=ids)
|
].drop(columns=ids)
|
||||||
X_val = X_train_all[lambda x: x.time_idx > training_cutoff]
|
X_val = X_train_all[X_train_all["time_idx"] > training_cutoff]
|
||||||
y_val = y_train_all[
|
y_val = y_train_all[
|
||||||
lambda x: x.time_idx > training_cutoff
|
y_train_all["time_idx"] > training_cutoff
|
||||||
].drop(columns=ids)
|
].drop(columns=ids)
|
||||||
else:
|
else:
|
||||||
num_samples = X_train_all.shape[0]
|
num_samples = X_train_all.shape[0]
|
||||||
|
@ -387,9 +590,8 @@ class GenericTask(Task):
|
||||||
X_val = X_train_all[split_idx:]
|
X_val = X_train_all[split_idx:]
|
||||||
y_val = y_train_all[split_idx:]
|
y_val = y_train_all[split_idx:]
|
||||||
else:
|
else:
|
||||||
if (
|
is_sample_weight = "sample_weight" in state.fit_kwargs
|
||||||
"sample_weight" in state.fit_kwargs
|
if not is_spark_dataframe and is_sample_weight:
|
||||||
): # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
|
||||||
(
|
(
|
||||||
X_train,
|
X_train,
|
||||||
X_val,
|
X_val,
|
||||||
|
@ -408,13 +610,30 @@ class GenericTask(Task):
|
||||||
test_size=split_ratio,
|
test_size=split_ratio,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
else:
|
elif not is_spark_dataframe and not is_sample_weight:
|
||||||
X_train, X_val, y_train, y_val = train_test_split(
|
X_train, X_val, y_train, y_val = train_test_split(
|
||||||
X_train_all,
|
X_train_all,
|
||||||
y_train_all,
|
y_train_all,
|
||||||
test_size=split_ratio,
|
test_size=split_ratio,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
)
|
)
|
||||||
|
elif is_spark_dataframe and is_sample_weight:
|
||||||
|
(
|
||||||
|
X_train,
|
||||||
|
X_val,
|
||||||
|
y_train,
|
||||||
|
y_val,
|
||||||
|
state.fit_kwargs[
|
||||||
|
"sample_weight"
|
||||||
|
], # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
||||||
|
state.weight_val,
|
||||||
|
) = self._split_pyspark(
|
||||||
|
state, X_train_all, y_train_all, split_ratio
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
X_train, X_val, y_train, y_val = self._split_pyspark(
|
||||||
|
state, X_train_all, y_train_all, split_ratio
|
||||||
|
)
|
||||||
elif split_type == "group":
|
elif split_type == "group":
|
||||||
gss = GroupShuffleSplit(
|
gss = GroupShuffleSplit(
|
||||||
n_splits=1, test_size=split_ratio, random_state=RANDOM_SEED
|
n_splits=1, test_size=split_ratio, random_state=RANDOM_SEED
|
||||||
|
@ -433,7 +652,7 @@ class GenericTask(Task):
|
||||||
elif self.is_classification():
|
elif self.is_classification():
|
||||||
# for classification, make sure the labels are complete in both
|
# for classification, make sure the labels are complete in both
|
||||||
# training and validation data
|
# training and validation data
|
||||||
label_set, first = np.unique(y_train_all, return_index=True)
|
label_set, first = unique_value_first_index(y_train_all)
|
||||||
rest = []
|
rest = []
|
||||||
last = 0
|
last = 0
|
||||||
first.sort()
|
first.sort()
|
||||||
|
@ -443,45 +662,17 @@ class GenericTask(Task):
|
||||||
rest.extend(range(last, len(y_train_all)))
|
rest.extend(range(last, len(y_train_all)))
|
||||||
X_first = X_train_all.iloc[first] if data_is_df else X_train_all[first]
|
X_first = X_train_all.iloc[first] if data_is_df else X_train_all[first]
|
||||||
X_rest = X_train_all.iloc[rest] if data_is_df else X_train_all[rest]
|
X_rest = X_train_all.iloc[rest] if data_is_df else X_train_all[rest]
|
||||||
y_rest = y_train_all[rest]
|
y_rest = (
|
||||||
|
y_train_all[rest]
|
||||||
|
if isinstance(y_train_all, np.ndarray)
|
||||||
|
else iloc_pandas_on_spark(y_train_all, rest)
|
||||||
|
if is_spark_dataframe
|
||||||
|
else y_train_all.iloc[rest]
|
||||||
|
)
|
||||||
stratify = y_rest if split_type == "stratified" else None
|
stratify = y_rest if split_type == "stratified" else None
|
||||||
if (
|
X_train, X_val, y_train, y_val = self._train_test_split(
|
||||||
"sample_weight" in state.fit_kwargs
|
state, X_rest, y_rest, first, rest, split_ratio, stratify
|
||||||
): # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
)
|
||||||
(
|
|
||||||
X_train,
|
|
||||||
X_val,
|
|
||||||
y_train,
|
|
||||||
y_val,
|
|
||||||
weight_train,
|
|
||||||
weight_val,
|
|
||||||
) = train_test_split(
|
|
||||||
X_rest,
|
|
||||||
y_rest,
|
|
||||||
state.fit_kwargs["sample_weight"][
|
|
||||||
rest
|
|
||||||
], # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
|
||||||
test_size=split_ratio,
|
|
||||||
stratify=stratify,
|
|
||||||
random_state=RANDOM_SEED,
|
|
||||||
)
|
|
||||||
weight1 = state.fit_kwargs["sample_weight"][
|
|
||||||
first
|
|
||||||
] # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
|
||||||
state.weight_val = concat(weight1, weight_val)
|
|
||||||
state.fit_kwargs[
|
|
||||||
"sample_weight"
|
|
||||||
] = concat( # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
|
||||||
weight1, weight_train
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
X_train, X_val, y_train, y_val = train_test_split(
|
|
||||||
X_rest,
|
|
||||||
y_rest,
|
|
||||||
test_size=split_ratio,
|
|
||||||
stratify=stratify,
|
|
||||||
random_state=RANDOM_SEED,
|
|
||||||
)
|
|
||||||
X_train = concat(X_first, X_train)
|
X_train = concat(X_first, X_train)
|
||||||
y_train = (
|
y_train = (
|
||||||
concat(label_set, y_train)
|
concat(label_set, y_train)
|
||||||
|
@ -495,58 +686,34 @@ class GenericTask(Task):
|
||||||
else np.concatenate([label_set, y_val])
|
else np.concatenate([label_set, y_val])
|
||||||
)
|
)
|
||||||
elif self.is_regression():
|
elif self.is_regression():
|
||||||
if (
|
X_train, X_val, y_train, y_val = self._train_test_split(
|
||||||
"sample_weight" in state.fit_kwargs
|
state, X_train_all, y_train_all, split_ratio=split_ratio
|
||||||
): # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
)
|
||||||
(
|
|
||||||
X_train,
|
|
||||||
X_val,
|
|
||||||
y_train,
|
|
||||||
y_val,
|
|
||||||
state.fit_kwargs[
|
|
||||||
"sample_weight"
|
|
||||||
], # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
|
||||||
state.weight_val,
|
|
||||||
) = train_test_split(
|
|
||||||
X_train_all,
|
|
||||||
y_train_all,
|
|
||||||
state.fit_kwargs[
|
|
||||||
"sample_weight"
|
|
||||||
], # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
|
||||||
test_size=split_ratio,
|
|
||||||
random_state=RANDOM_SEED,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
X_train, X_val, y_train, y_val = train_test_split(
|
|
||||||
X_train_all,
|
|
||||||
y_train_all,
|
|
||||||
test_size=split_ratio,
|
|
||||||
random_state=RANDOM_SEED,
|
|
||||||
)
|
|
||||||
state.data_size = X_train.shape
|
state.data_size = X_train.shape
|
||||||
state.X_train, state.y_train = X_train, y_train
|
state.X_train, state.y_train = X_train, y_train
|
||||||
state.X_val, state.y_val = X_val, y_val
|
state.X_val, state.y_val = X_val, y_val
|
||||||
state.X_train_all = X_train_all
|
state.X_train_all = X_train_all
|
||||||
state.y_train_all = y_train_all
|
state.y_train_all = y_train_all
|
||||||
|
y_train_all_size = y_train_all.size
|
||||||
if eval_method == "holdout":
|
if eval_method == "holdout":
|
||||||
state.kf = None
|
state.kf = None
|
||||||
return
|
return
|
||||||
if split_type == "group":
|
if split_type == "group":
|
||||||
# logger.info("Using GroupKFold")
|
# logger.info("Using GroupKFold")
|
||||||
assert (
|
assert (
|
||||||
len(state.groups_all) == y_train_all.size
|
len(state.groups_all) == y_train_all_size
|
||||||
), "the length of groups must match the number of examples"
|
), "the length of groups must match the number of examples"
|
||||||
assert (
|
assert (
|
||||||
len(np.unique(state.groups_all)) >= n_splits
|
len_labels(state.groups_all) >= n_splits
|
||||||
), "the number of groups must be equal or larger than n_splits"
|
), "the number of groups must be equal or larger than n_splits"
|
||||||
state.kf = GroupKFold(n_splits)
|
state.kf = GroupKFold(n_splits)
|
||||||
elif split_type == "stratified":
|
elif split_type == "stratified":
|
||||||
# logger.info("Using StratifiedKFold")
|
# logger.info("Using StratifiedKFold")
|
||||||
assert y_train_all.size >= n_splits, (
|
assert y_train_all_size >= n_splits, (
|
||||||
f"{n_splits}-fold cross validation"
|
f"{n_splits}-fold cross validation"
|
||||||
f" requires input data with at least {n_splits} examples."
|
f" requires input data with at least {n_splits} examples."
|
||||||
)
|
)
|
||||||
assert y_train_all.size >= 2 * n_splits, (
|
assert y_train_all_size >= 2 * n_splits, (
|
||||||
f"{n_splits}-fold cross validation with metric=r2 "
|
f"{n_splits}-fold cross validation with metric=r2 "
|
||||||
f"requires input data with at least {n_splits*2} examples."
|
f"requires input data with at least {n_splits*2} examples."
|
||||||
)
|
)
|
||||||
|
@ -559,8 +726,8 @@ class GenericTask(Task):
|
||||||
period = state.fit_kwargs[
|
period = state.fit_kwargs[
|
||||||
"period"
|
"period"
|
||||||
] # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
] # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
|
||||||
if period * (n_splits + 1) > y_train_all.size:
|
if period * (n_splits + 1) > y_train_all_size:
|
||||||
n_splits = int(y_train_all.size / period - 1)
|
n_splits = int(y_train_all_size / period - 1)
|
||||||
assert n_splits >= 2, (
|
assert n_splits >= 2, (
|
||||||
f"cross validation for forecasting period={period}"
|
f"cross validation for forecasting period={period}"
|
||||||
f" requires input data with at least {3 * period} examples."
|
f" requires input data with at least {3 * period} examples."
|
||||||
|
@ -568,7 +735,9 @@ class GenericTask(Task):
|
||||||
logger.info(f"Using nsplits={n_splits} due to data size limit.")
|
logger.info(f"Using nsplits={n_splits} due to data size limit.")
|
||||||
state.kf = TimeSeriesSplit(n_splits=n_splits, test_size=period)
|
state.kf = TimeSeriesSplit(n_splits=n_splits, test_size=period)
|
||||||
elif self.is_ts_forecastpanel():
|
elif self.is_ts_forecastpanel():
|
||||||
n_groups = X_train.groupby(state.fit_kwargs.get("group_ids")).ngroups
|
n_groups = len(
|
||||||
|
X_train.groupby(state.fit_kwargs.get("group_ids")).size()
|
||||||
|
)
|
||||||
period = state.fit_kwargs.get("period")
|
period = state.fit_kwargs.get("period")
|
||||||
state.kf = TimeSeriesSplit(
|
state.kf = TimeSeriesSplit(
|
||||||
n_splits=n_splits, test_size=period * n_groups
|
n_splits=n_splits, test_size=period * n_groups
|
||||||
|
@ -595,7 +764,7 @@ class GenericTask(Task):
|
||||||
groups=None,
|
groups=None,
|
||||||
) -> str:
|
) -> str:
|
||||||
if self.name == "classification":
|
if self.name == "classification":
|
||||||
self.name = get_classification_objective(len(np.unique(y_train_all)))
|
self.name = get_classification_objective(len_labels(y_train_all))
|
||||||
if not isinstance(split_type, str):
|
if not isinstance(split_type, str):
|
||||||
assert hasattr(split_type, "split") and hasattr(
|
assert hasattr(split_type, "split") and hasattr(
|
||||||
split_type, "get_n_splits"
|
split_type, "get_n_splits"
|
||||||
|
@ -661,6 +830,8 @@ class GenericTask(Task):
|
||||||
)
|
)
|
||||||
elif isinstance(X, int):
|
elif isinstance(X, int):
|
||||||
return X
|
return X
|
||||||
|
elif isinstance(X, psDataFrame):
|
||||||
|
return X
|
||||||
elif issparse(X):
|
elif issparse(X):
|
||||||
X = X.tocsr()
|
X = X.tocsr()
|
||||||
if self.is_ts_forecast():
|
if self.is_ts_forecast():
|
||||||
|
@ -695,60 +866,87 @@ class GenericTask(Task):
|
||||||
train_time = pred_time = 0
|
train_time = pred_time = 0
|
||||||
total_fold_num = 0
|
total_fold_num = 0
|
||||||
n = kf.get_n_splits()
|
n = kf.get_n_splits()
|
||||||
X_train_split, y_train_split = X_train_all, y_train_all
|
rng = np.random.RandomState(2020)
|
||||||
|
budget_per_train = budget and budget / n
|
||||||
|
groups = None
|
||||||
if self.is_classification():
|
if self.is_classification():
|
||||||
labels = np.unique(y_train_all)
|
labels = _, labels = len_labels(y_train_all, return_labels=True)
|
||||||
else:
|
else:
|
||||||
labels = fit_kwargs.get(
|
labels = fit_kwargs.get(
|
||||||
"label_list"
|
"label_list"
|
||||||
) # pass the label list on to compute the evaluation metric
|
) # pass the label list on to compute the evaluation metric
|
||||||
groups = None
|
|
||||||
shuffle = getattr(kf, "shuffle", not self.is_ts_forecast())
|
|
||||||
if isinstance(kf, RepeatedStratifiedKFold):
|
|
||||||
kf = kf.split(X_train_split, y_train_split)
|
|
||||||
elif isinstance(kf, (GroupKFold, StratifiedGroupKFold)):
|
|
||||||
groups = kf.groups
|
|
||||||
kf = kf.split(X_train_split, y_train_split, groups)
|
|
||||||
shuffle = False
|
|
||||||
elif isinstance(kf, TimeSeriesSplit):
|
|
||||||
kf = kf.split(X_train_split, y_train_split)
|
|
||||||
else:
|
|
||||||
kf = kf.split(X_train_split)
|
|
||||||
rng = np.random.RandomState(2020)
|
|
||||||
budget_per_train = budget and budget / n
|
|
||||||
if "sample_weight" in fit_kwargs:
|
if "sample_weight" in fit_kwargs:
|
||||||
weight = fit_kwargs["sample_weight"]
|
weight = fit_kwargs["sample_weight"]
|
||||||
weight_val = None
|
weight_val = None
|
||||||
else:
|
else:
|
||||||
weight = weight_val = None
|
weight = weight_val = None
|
||||||
|
|
||||||
|
is_spark_dataframe = isinstance(X_train_all, (psDataFrame, psSeries))
|
||||||
|
if is_spark_dataframe:
|
||||||
|
dataframe = X_train_all.join(y_train_all)
|
||||||
|
if weight is not None:
|
||||||
|
dataframe = dataframe.join(weight)
|
||||||
|
if isinstance(kf, (GroupKFold, StratifiedGroupKFold)):
|
||||||
|
groups = kf.groups
|
||||||
|
dataframe = dataframe.join(groups)
|
||||||
|
kf = spark_kFold(
|
||||||
|
dataframe, nFolds=n, foldCol=groups.name if groups is not None else ""
|
||||||
|
)
|
||||||
|
shuffle = False
|
||||||
|
else:
|
||||||
|
X_train_split, y_train_split = X_train_all, y_train_all
|
||||||
|
shuffle = getattr(kf, "shuffle", not self.is_ts_forecast())
|
||||||
|
if isinstance(kf, RepeatedStratifiedKFold):
|
||||||
|
kf = kf.split(X_train_split, y_train_split)
|
||||||
|
elif isinstance(kf, (GroupKFold, StratifiedGroupKFold)):
|
||||||
|
groups = kf.groups
|
||||||
|
kf = kf.split(X_train_split, y_train_split, groups)
|
||||||
|
shuffle = False
|
||||||
|
elif isinstance(kf, TimeSeriesSplit):
|
||||||
|
kf = kf.split(X_train_split, y_train_split)
|
||||||
|
else:
|
||||||
|
kf = kf.split(X_train_split)
|
||||||
|
|
||||||
for train_index, val_index in kf:
|
for train_index, val_index in kf:
|
||||||
if shuffle:
|
if shuffle:
|
||||||
train_index = rng.permutation(train_index)
|
train_index = rng.permutation(train_index)
|
||||||
if isinstance(X_train_all, pd.DataFrame):
|
if is_spark_dataframe:
|
||||||
|
# cache data to increase compute speed
|
||||||
|
X_train = train_index.spark.cache()
|
||||||
|
X_val = val_index.spark.cache()
|
||||||
|
y_train = X_train.pop(y_train_all.name)
|
||||||
|
y_val = X_val.pop(y_train_all.name)
|
||||||
|
if weight is not None:
|
||||||
|
weight_val = X_val.pop(weight.name)
|
||||||
|
fit_kwargs["sample_weight"] = X_train.pop(weight.name)
|
||||||
|
groups_val = None
|
||||||
|
elif isinstance(X_train_all, pd.DataFrame):
|
||||||
X_train = X_train_split.iloc[train_index]
|
X_train = X_train_split.iloc[train_index]
|
||||||
X_val = X_train_split.iloc[val_index]
|
X_val = X_train_split.iloc[val_index]
|
||||||
else:
|
else:
|
||||||
X_train, X_val = X_train_split[train_index], X_train_split[val_index]
|
X_train, X_val = X_train_split[train_index], X_train_split[val_index]
|
||||||
y_train, y_val = y_train_split[train_index], y_train_split[val_index]
|
if not is_spark_dataframe:
|
||||||
|
y_train, y_val = y_train_split[train_index], y_train_split[val_index]
|
||||||
|
if weight is not None:
|
||||||
|
fit_kwargs["sample_weight"], weight_val = (
|
||||||
|
weight[train_index],
|
||||||
|
weight[val_index],
|
||||||
|
)
|
||||||
|
if groups is not None:
|
||||||
|
fit_kwargs["groups"] = (
|
||||||
|
groups[train_index]
|
||||||
|
if isinstance(groups, np.ndarray)
|
||||||
|
else groups.iloc[train_index]
|
||||||
|
)
|
||||||
|
groups_val = (
|
||||||
|
groups[val_index]
|
||||||
|
if isinstance(groups, np.ndarray)
|
||||||
|
else groups.iloc[val_index]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
groups_val = None
|
||||||
|
|
||||||
estimator.cleanup()
|
estimator.cleanup()
|
||||||
if weight is not None:
|
|
||||||
fit_kwargs["sample_weight"], weight_val = (
|
|
||||||
weight[train_index],
|
|
||||||
weight[val_index],
|
|
||||||
)
|
|
||||||
if groups is not None:
|
|
||||||
fit_kwargs["groups"] = (
|
|
||||||
groups[train_index]
|
|
||||||
if isinstance(groups, np.ndarray)
|
|
||||||
else groups.iloc[train_index]
|
|
||||||
)
|
|
||||||
groups_val = (
|
|
||||||
groups[val_index]
|
|
||||||
if isinstance(groups, np.ndarray)
|
|
||||||
else groups.iloc[val_index]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
groups_val = None
|
|
||||||
val_loss_i, metric_i, train_time_i, pred_time_i = get_val_loss(
|
val_loss_i, metric_i, train_time_i, pred_time_i = get_val_loss(
|
||||||
config,
|
config,
|
||||||
estimator,
|
estimator,
|
||||||
|
@ -775,6 +973,9 @@ class GenericTask(Task):
|
||||||
log_metric_folds.append(metric_i)
|
log_metric_folds.append(metric_i)
|
||||||
train_time += train_time_i
|
train_time += train_time_i
|
||||||
pred_time += pred_time_i
|
pred_time += pred_time_i
|
||||||
|
if is_spark_dataframe:
|
||||||
|
X_train.spark.unpersist() # uncache data to free memory
|
||||||
|
X_val.spark.unpersist() # uncache data to free memory
|
||||||
if budget and time.time() - start_time >= budget:
|
if budget and time.time() - start_time >= budget:
|
||||||
break
|
break
|
||||||
val_loss, metric = cv_score_agg_func(val_loss_folds, log_metric_folds)
|
val_loss, metric = cv_score_agg_func(val_loss_folds, log_metric_folds)
|
||||||
|
@ -782,11 +983,44 @@ class GenericTask(Task):
|
||||||
pred_time /= n
|
pred_time /= n
|
||||||
return val_loss, metric, train_time, pred_time
|
return val_loss, metric, train_time, pred_time
-    def default_estimator_list(self, estimator_list: List[str]) -> List[str]:
+    def default_estimator_list(
+        self, estimator_list: List[str], is_spark_dataframe: bool = False
+    ) -> List[str]:
         if "auto" != estimator_list:
+            n_estimators = len(estimator_list)
+            if is_spark_dataframe:
+                # For spark dataframe, only estimators ending with '_spark' are supported
+                estimator_list = [
+                    est for est in estimator_list if est.endswith("_spark")
+                ]
+                if len(estimator_list) == 0:
+                    raise ValueError(
+                        "Spark dataframes only support estimator names ending with `_spark`. Non-supported "
+                        "estimators are removed. No estimator is left."
+                    )
+                elif n_estimators != len(estimator_list):
+                    logger.warning(
+                        "Spark dataframes only support estimator names ending with `_spark`. Non-supported "
+                        "estimators are removed."
+                    )
+            else:
+                # For non-spark dataframe, only estimators not ending with '_spark' are supported
+                estimator_list = [
+                    est for est in estimator_list if not est.endswith("_spark")
+                ]
+                if len(estimator_list) == 0:
+                    raise ValueError(
+                        "Non-spark dataframes only support estimator names not ending with `_spark`. Non-supported "
+                        "estimators are removed. No estimator is left."
+                    )
+                elif n_estimators != len(estimator_list):
+                    logger.warning(
+                        "Non-spark dataframes only support estimator names not ending with `_spark`. Non-supported "
+                        "estimators are removed."
+                    )
             return estimator_list
         if self.is_rank():
-            estimator_list = ["lgbm", "xgboost", "xgb_limitdepth"]
+            estimator_list = ["lgbm", "xgboost", "xgb_limitdepth", "lgbm_spark"]
         elif self.is_nlp():
             estimator_list = ["transformer"]
         elif self.is_ts_forecastpanel():
@@ -802,6 +1036,7 @@ class GenericTask(Task):
                 "xgboost",
                 "extra_tree",
                 "xgb_limitdepth",
+                "lgbm_spark",
             ]
         except ImportError:
             estimator_list = [
@@ -810,6 +1045,7 @@ class GenericTask(Task):
                 "xgboost",
                 "extra_tree",
                 "xgb_limitdepth",
+                "lgbm_spark",
             ]
         if self.is_ts_forecast():
             # catboost is removed because it has a `name` parameter, making it incompatible with hcrystalball
@@ -825,6 +1061,15 @@ class GenericTask(Task):
         elif not self.is_regression():
             estimator_list += ["lrl1"]
+
+        estimator_list = [
+            est
+            for est in estimator_list
+            if (
+                est.endswith("_spark")
+                if is_spark_dataframe
+                else not est.endswith("_spark")
+            )
+        ]
         return estimator_list

     def default_metric(self, metric: str) -> str:
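To make the selection rule concrete, here is the same filter extracted into a standalone function with a toy estimator list (the helper name is ours, not part of the codebase):

def _filter(estimator_list, is_spark_dataframe):
    return [
        est
        for est in estimator_list
        if (est.endswith("_spark") if is_spark_dataframe else not est.endswith("_spark"))
    ]

print(_filter(["lgbm", "lgbm_spark", "xgboost"], True))   # ['lgbm_spark']
print(_filter(["lgbm", "lgbm_spark", "xgboost"], False))  # ['lgbm', 'xgboost']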
@@ -255,7 +255,9 @@ class Task(ABC):

     @abstractmethod
     def default_estimator_list(
-        self, estimator_list: Union[List[str], str] = "auto"
+        self,
+        estimator_list: Union[List[str], str] = "auto",
+        is_spark_dataframe: bool = False,
     ) -> List[str]:
         """Return the list of default estimators registered for this task type.

@@ -264,6 +266,7 @@ class Task(ABC):

         Args:
             estimator_list: Either 'auto' or a list of estimator names to be validated.
+            is_spark_dataframe: True if the data is a spark dataframe.

         Returns:
             A list of valid estimator names for this task type.
@@ -0,0 +1,18 @@
from typing import Optional, Union, Tuple
import numpy as np


def len_labels(y: np.ndarray, return_labels=False) -> Union[int, Optional[np.ndarray]]:
    """Get the number of unique labels in y. The non-spark version of
    flaml.automl.spark.utils.len_labels"""
    labels = np.unique(y)
    if return_labels:
        return len(labels), labels
    return len(labels)


def unique_value_first_index(y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Get the unique values and indices of a pandas series or numpy array.
    The non-spark version of flaml.automl.spark.utils.unique_value_first_index"""
    label_set, first_index = np.unique(y, return_index=True)
    return label_set, first_index
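A quick illustration of the two helpers (the labels are arbitrary):

import numpy as np
from flaml.automl.utils import len_labels, unique_value_first_index

y = np.array(["cat", "dog", "cat", "bird", "dog"])
print(len_labels(y))                      # 3
print(len_labels(y, return_labels=True))  # (3, array(['bird', 'cat', 'dog'], dtype='<U4'))
print(unique_value_first_index(y))        # (array(['bird', 'cat', 'dog'], dtype='<U4'), array([3, 0, 1]))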
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.neighbors import NearestNeighbors
|
from sklearn.neighbors import NearestNeighbors
|
||||||
import logging
|
import logging
|
||||||
|
@ -8,6 +9,24 @@ from flaml.automl.task.task import CLASSIFICATION, get_classification_objective
|
||||||
from flaml.automl.ml import get_estimator_class
|
from flaml.automl.ml import get_estimator_class
|
||||||
from flaml.version import __version__
|
from flaml.version import __version__
|
||||||
|
|
||||||
|
try:
|
||||||
|
from flaml.automl.spark.utils import len_labels
|
||||||
|
except ImportError:
|
||||||
|
from flaml.automl.utils import len_labels
|
||||||
|
try:
|
||||||
|
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
|
||||||
|
import pyspark.pandas as ps
|
||||||
|
from pyspark.pandas import DataFrame as psDataFrame, Series as psSeries
|
||||||
|
except ImportError:
|
||||||
|
ps = None
|
||||||
|
|
||||||
|
class psDataFrame:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class psSeries:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
LOCATION = pathlib.Path(__file__).parent.resolve()
|
LOCATION = pathlib.Path(__file__).parent.resolve()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
CONFIG_PREDICTORS = {}
|
CONFIG_PREDICTORS = {}
|
||||||
|
@ -29,12 +48,15 @@ def meta_feature(task, X_train, y_train, meta_feature_names):
|
||||||
elif each_feature_name == "NumberOfFeatures":
|
elif each_feature_name == "NumberOfFeatures":
|
||||||
this_feature.append(n_feat)
|
this_feature.append(n_feat)
|
||||||
elif each_feature_name == "NumberOfClasses":
|
elif each_feature_name == "NumberOfClasses":
|
||||||
this_feature.append(len(np.unique(y_train)) if is_classification else 0)
|
this_feature.append(len_labels(y_train) if is_classification else 0)
|
||||||
elif each_feature_name == "PercentageOfNumericFeatures":
|
elif each_feature_name == "PercentageOfNumericFeatures":
|
||||||
try:
|
try:
|
||||||
# this is feature is only supported for dataframe
|
# this feature is only supported for dataframe
|
||||||
this_feature.append(
|
this_feature.append(
|
||||||
X_train.select_dtypes(include=np.number).shape[1] / n_feat
|
X_train.select_dtypes(
|
||||||
|
include=[np.number, "float", "int", "long"]
|
||||||
|
).shape[1]
|
||||||
|
/ n_feat
|
||||||
)
|
)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# 'numpy.ndarray' object has no attribute 'select_dtypes'
|
# 'numpy.ndarray' object has no attribute 'select_dtypes'
|
||||||
|
@ -78,7 +100,7 @@ def suggest_config(
|
||||||
`FLAML_sample_size` is removed from the configs.
|
`FLAML_sample_size` is removed from the configs.
|
||||||
"""
|
"""
|
||||||
task = (
|
task = (
|
||||||
get_classification_objective(len(np.unique(y)))
|
get_classification_objective(len_labels(y))
|
||||||
if task == "classification" and y is not None
|
if task == "classification" and y is not None
|
||||||
else task
|
else task
|
||||||
)
|
)
|
||||||
|
|
|
@@ -10,8 +10,9 @@ logger = logging.getLogger(__name__)
 logger_formatter = logging.Formatter(
     "[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s", "%m-%d %H:%M:%S"
 )
+logger.propagate = False
 try:
+    os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
     import pyspark
     from pyspark.sql import SparkSession
     from pyspark.util import VersionUtils
@ -248,6 +248,7 @@ def run(
     log_file_name: Optional[str] = None,
     lexico_objectives: Optional[dict] = None,
     force_cancel: Optional[bool] = False,
+    n_concurrent_trials: Optional[int] = 0,
     **ray_args,
 ):
     """The trigger for HPO.
@ -437,6 +438,14 @@ def run(
            "targets": {"error_rate": 0.0},
        }
        ```
+        force_cancel: boolean, default=False | Whether to forcibly cancel the PySpark job if it runs overtime.
+        n_concurrent_trials: int, default=0 | The number of concurrent trials when performing hyperparameter
+            tuning with Spark. Only valid when use_spark=True, and Spark is required:
+            `pip install flaml[spark]`. Please check
+            [here](https://spark.apache.org/docs/latest/api/python/getting_started/install.html)
+            for more details about installing Spark. When tune.run() is called from AutoML, it will be
+            overwritten by the value of `n_concurrent_trials` in AutoML. When <= 0, the number of concurrent
+            trials will be set to the number of executors.
         **ray_args: keyword arguments to pass to ray.tune.run().
             Only valid when use_ray=True.
     """
@ -674,18 +683,30 @@ def run(
         is not an instance of `ConcurrencyLimiter`.

         The final number of concurrent trials is the minimum of `max_concurrent` and
-        `num_executors`.
+        `num_executors` if `n_concurrent_trials<=0` (default, automl cases), otherwise the
+        minimum of `max_concurrent` and `n_concurrent_trials` (tuning cases).
     """
-    num_executors = max(num_executors, int(os.getenv("FLAML_MAX_CONCURRENT", 1)), 1)
     time_start = time.time()
+    try:
+        FLAML_MAX_CONCURRENT = int(os.getenv("FLAML_MAX_CONCURRENT", 0))
+        num_executors = max(num_executors, FLAML_MAX_CONCURRENT, 1)
+    except ValueError:
+        FLAML_MAX_CONCURRENT = 0
+    max_spark_parallelism = (
+        min(spark.sparkContext.defaultParallelism, FLAML_MAX_CONCURRENT)
+        if FLAML_MAX_CONCURRENT > 0
+        else spark.sparkContext.defaultParallelism
+    )
     if scheduler:
         scheduler.set_search_properties(metric=metric, mode=mode)
     if isinstance(search_alg, ConcurrencyLimiter):
         max_concurrent = max(1, search_alg.max_concurrent)
     else:
-        max_concurrent = max(1, int(os.getenv("FLAML_MAX_CONCURRENT", 1)))
+        max_concurrent = max(1, max_spark_parallelism)
-    n_concurrent_trials = min(num_executors, max_concurrent)
+    n_concurrent_trials = min(
+        n_concurrent_trials if n_concurrent_trials > 0 else num_executors,
+        max_concurrent,
+    )
     with parallel_backend("spark"):
         with Parallel(
             n_jobs=n_concurrent_trials, verbose=max(0, (verbose - 1) * 50)
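A minimal sketch of how the new arguments could be exercised from `flaml.tune` (the `use_spark`, `n_concurrent_trials`, and `force_cancel` names come from this diff; the objective and search space below are made up for illustration and assume a working Spark environment):

```python
from flaml import tune

def toy_objective(config):
    # hypothetical objective for illustration only
    tune.report(score=(config["x"] - 3) ** 2)

analysis = tune.run(
    toy_objective,
    config={"x": tune.uniform(0, 10)},  # made-up search space
    metric="score",
    mode="min",
    time_budget_s=5,
    use_spark=True,          # run trials through the Spark backend instead of Ray
    n_concurrent_trials=4,   # added in this diff; <= 0 falls back to the number of executors
    force_cancel=True,       # cancel the PySpark job if it runs over the time budget
)
```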
@ -0,0 +1,831 @@ (new notebook; cells rendered below, long raw outputs summarized in brackets)

[markdown]
# AutoML with FLAML Library for SynapseML models and Spark dataframes

## 1. Introduction

FLAML is a Python library (https://github.com/microsoft/FLAML) designed to automatically produce accurate machine learning models with low computational cost. It is fast and economical. The simple and lightweight design makes it easy to use and extend, such as adding new learners. FLAML can
- serve as an economical AutoML engine,
- be used as a fast hyperparameter tuning tool, or
- be embedded in self-tuning software that requires low latency & resources in repetitive tuning tasks.

In this notebook, we demonstrate how to use the FLAML library to run AutoML on SynapseML models and Spark dataframes, and we compare the results of FLAML AutoML against the default SynapseML model. In this example, we use LightGBM to build a classification model that predicts bankruptcy.

Since the dataset is unbalanced, `AUC` is a better metric than `Accuracy`. FLAML (1 min of training) achieved AUC **0.79**, while the default SynapseML model only got AUC **0.64**.

FLAML requires `Python>=3.7`. To run this notebook example, please install flaml with the `synapse` option:
```bash
pip install "flaml[synapse]>=1.1.3"
```

[code]
# %pip install "flaml[synapse]>=1.1.3"

[markdown]
## 2. Load data and preprocess
[code]
import pyspark

spark = (
    pyspark.sql.SparkSession.builder.appName("MyApp")
    .config(
        "spark.jars.packages",
        f"com.microsoft.azure:synapseml_2.12:0.10.2,org.apache.hadoop:hadoop-azure:{pyspark.__version__},com.microsoft.azure:azure-storage:8.6.6",
    )
    .config("spark.sql.debug.maxToStringFields", "100")
    .getOrCreate()
)

[output: Ivy resolves com.microsoft.azure#synapseml_2.12;0.10.2, org.apache.hadoop#hadoop-azure;3.3.1 and com.microsoft.azure#azure-storage;8.6.6 together with their 57 transitive modules (0 downloaded, 4 evicted); Spark then warns that the native-hadoop library could not be loaded and sets the default log level to "WARN".]

[code]
df = (
    spark.read.format("csv")
    .option("header", True)
    .option("inferSchema", True)
    .load(
        "wasbs://publicwasb@mmlspark.blob.core.windows.net/company_bankruptcy_prediction_data.csv"
    )
)
# print dataset size
print("records read: " + str(df.count()))
print("Schema: ")
df.printSchema()

[output: records read: 6819; the printed schema shows the integer label column "Bankrupt?" followed by the financial-ratio feature columns (all nullable doubles), from "ROA(C) before interest and depreciation before interest" through "Equity to Liability".]
[markdown]
Split the dataset into train and test

[code]
train, test = df.randomSplit([0.8, 0.2], seed=41)

[markdown]
Add a featurizer to convert the features to a vector

[code]
from pyspark.ml.feature import VectorAssembler

feature_cols = df.columns[1:]
featurizer = VectorAssembler(inputCols=feature_cols, outputCol="features")
train_data = featurizer.transform(train)["Bankrupt?", "features"]
test_data = featurizer.transform(test)["Bankrupt?", "features"]
[markdown]
### Default SynapseML LightGBM

[code]
from synapse.ml.lightgbm import LightGBMClassifier

model = LightGBMClassifier(
    objective="binary", featuresCol="features", labelCol="Bankrupt?", isUnbalance=True
)

model = model.fit(train_data)

[output: LightGBM warns that whitespaces in feature names were replaced with underlines.]

[markdown]
#### Model Prediction
[code]
def predict(model):
    from synapse.ml.train import ComputeModelStatistics

    predictions = model.transform(test_data)
    # predictions.limit(10).show()

    metrics = ComputeModelStatistics(
        evaluationMetric="classification",
        labelCol="Bankrupt?",
        scoredLabelsCol="prediction",
    ).transform(predictions)
    display(metrics)
    return metrics

default_metrics = predict(model)
default_metrics.show()

[output]
+---------------+--------------------+-----------------+------------------+-------------------+------------------+
|evaluation_type|    confusion_matrix|         accuracy|         precision|             recall|               AUC|
+---------------+--------------------+-----------------+------------------+-------------------+------------------+
| Classification|1250.0  23.0  \n3...|0.958997722095672|0.3611111111111111|0.29545454545454547|0.6386934942512319|
+---------------+--------------------+-----------------+------------------+-------------------+------------------+

[markdown]
### Run FLAML
In the FLAML AutoML run configuration, users can specify the task type, time budget, error metric, learner list, whether to subsample, the resampling strategy type, and so on. All of these arguments have default values which will be used if users do not provide them.
[code]
''' import AutoML class from flaml package '''
from flaml import AutoML
from flaml.automl.spark.utils import to_pandas_on_spark

automl = AutoML()

[code]
import os

settings = {
    "time_budget": 30,  # total running time in seconds
    "metric": 'roc_auc',
    "estimator_list": ['lgbm_spark'],  # list of ML learners; we tune lightgbm in this example
    "task": 'classification',  # task type
    "log_file_name": 'flaml_experiment.log',  # flaml log file
    "seed": 41,  # random seed
    "force_cancel": True,  # force stop training once time_budget is used up
}
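As a minimal illustration of the `to_pandas_on_spark` helper imported above (its behavior matches the unit test added later in this diff), a plain pandas dataframe can be converted into the pandas-on-Spark dataframe that AutoML expects for the `lgbm_spark` estimator:

```python
import pandas as pd

pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
psdf = to_pandas_on_spark(pdf)  # returns a pyspark.pandas.DataFrame
print(type(psdf))
```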
[markdown]
Disable Arrow optimization to omit the warning below:
```
/opt/spark/python/lib/pyspark.zip/pyspark/sql/pandas/conversion.py:87: UserWarning: toPandas attempted Arrow optimization because 'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, failed by the reason below:
  Unsupported type in conversion to Arrow: VectorUDT
Attempting non-optimization as 'spark.sql.execution.arrow.pyspark.fallback.enabled' is set to true.
  warnings.warn(msg)
```
[code]
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")

[code]
df = to_pandas_on_spark(to_pandas_on_spark(train_data).to_spark(index_col="index"))

df.head()

[output: the first five rows of the pandas-on-Spark dataframe, with columns "index", "Bankrupt?" (all 0 in these rows), and "features" (the assembled dense feature vectors).]
[code]
'''The main flaml automl API'''
automl.fit(dataframe=df, label='Bankrupt?', labelCol="Bankrupt?", isUnbalance=True, **settings)

[output (abridged; PandasAPIOnSparkAdviceWarning and LightGBM whitespace warnings omitted)]
[flaml.automl.automl: 02-28 02:12:59] {2922} INFO - task = classification
[flaml.automl.automl: 02-28 02:13:00] {2924} INFO - Data split method: stratified
[flaml.automl.automl: 02-28 02:13:00] {2927} INFO - Evaluation method: cv
[flaml.automl.automl: 02-28 02:13:01] {3054} INFO - Minimizing error metric: 1-roc_auc
[flaml.automl.automl: 02-28 02:13:01] {3209} INFO - List of ML learners in AutoML Run: ['lgbm_spark']
[flaml.automl.automl: 02-28 02:13:01] {3539} INFO - iteration 0, current learner lgbm_spark
[flaml.automl.automl: 02-28 02:13:48] {3677} INFO - Estimated sufficient time budget=464999s. Estimated necessary time budget=465s.
[flaml.automl.automl: 02-28 02:13:48] {3724} INFO - at 48.5s, estimator lgbm_spark's best error=0.0871, best estimator lgbm_spark's best error=0.0871
[flaml.automl.automl: 02-28 02:13:54] {3988} INFO - retrain lgbm_spark for 6.2s
[flaml.automl.automl: 02-28 02:13:54] {3995} INFO - retrained model: LightGBMClassifier_a2177c5be001
[flaml.automl.automl: 02-28 02:13:54] {3239} INFO - fit succeeded
[flaml.automl.automl: 02-28 02:13:54] {3240} INFO - Time taken to find the best model: 48.4579541683197

[markdown]
### Best model and metric
[code]
''' retrieve best config'''
print('Best hyperparameter config:', automl.best_config)
print('Best roc_auc on validation data: {0:.4g}'.format(1 - automl.best_loss))
print('Training duration of best run: {0:.4g} s'.format(automl.best_config_train_time))

[output]
Best hyperparameter config: {'numIterations': 4, 'numLeaves': 4, 'minDataInLeaf': 20, 'learningRate': 0.09999999999999995, 'log_max_bin': 8, 'featureFraction': 1.0, 'lambdaL1': 0.0009765625, 'lambdaL2': 1.0}
Best roc_auc on validation data: 0.9129
Training duration of best run: 6.237 s

[code]
flaml_metrics = predict(automl.model.estimator)
flaml_metrics.show()

[output]
+---------------+--------------------+------------------+-------------------+------------------+------------------+
|evaluation_type|    confusion_matrix|          accuracy|          precision|            recall|               AUC|
+---------------+--------------------+------------------+-------------------+------------------+------------------+
| Classification|1218.0  55.0  \n1...|0.9453302961275627|0.32926829268292684|0.6136363636363636|0.7852156680711276|
+---------------+--------------------+------------------+-------------------+------------------+------------------+

[notebook metadata: kernelspec "flaml-dev" (Python 3.10.8), nbformat 4, no widgets]
setup.py (8 changed lines)

@ -45,7 +45,7 @@ setuptools.setup(
         "openml==0.10.2",
     ],
     "spark": [
-        "pyspark>=3.0.0",
+        "pyspark>=3.2.0",
         "joblibspark>=0.5.0",
     ],
     "test": [
@ -71,7 +71,7 @@ setuptools.setup(
         "seqeval",
         "pytorch-forecasting>=0.9.0,<=0.10.1",
         "mlflow",
-        "pyspark>=3.0.0",
+        "pyspark>=3.2.0",
         "joblibspark>=0.5.0",
         "nbconvert",
         "nbformat",
@ -120,8 +120,8 @@ setuptools.setup(
         "pytorch-forecasting>=0.9.0",
     ],
     "benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"],
-    "openai": ["openai==0.27.0", "diskcache", "optuna==2.8.0"],
+    "openai": ["openai==0.23.1", "diskcache", "optuna==2.8.0"],
-    "synapse": ["joblibspark>=0.5.0", "optuna==2.8.0", "pyspark>=3.0.0"],
+    "synapse": ["joblibspark>=0.5.0", "optuna==2.8.0", "pyspark>=3.2.0"],
     },
     classifiers=[
         "Programming Language :: Python :: 3",
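Since the floor is raised to `pyspark>=3.2.0` (needed for the pandas-on-Spark API), a quick runtime check can confirm the installed version; this sketch reuses the `VersionUtils` helper that already appears earlier in this diff:

```python
import pyspark
from pyspark.util import VersionUtils

# VersionUtils.majorMinorVersion parses a version string like "3.2.3" into (3, 2)
major, minor = VersionUtils.majorMinorVersion(pyspark.__version__)
assert (major, minor) >= (3, 2), f"pyspark {pyspark.__version__} is older than the required 3.2.0"
```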
@ -0,0 +1,20 @@ (new test file)
import numpy as np
from flaml.automl.utils import len_labels, unique_value_first_index


def test_len_labels():
    assert len_labels([1, 2, 3]) == 3
    assert len_labels([1, 2, 3, 1, 2, 3]) == 3
    assert np.array_equal(len_labels([1, 2, 3], True)[1], [1, 2, 3])
    assert np.array_equal(len_labels([1, 2, 3, 1, 2, 3], True)[1], [1, 2, 3])


def test_unique_value_first_index():
    label_set, first_index = unique_value_first_index([1, 2, 2, 3])
    assert np.array_equal(label_set, np.array([1, 2, 3]))
    assert np.array_equal(first_index, np.array([0, 1, 3]))


if __name__ == "__main__":
    test_len_labels()
    test_unique_value_first_index()
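Outside the test, the two helpers behave as the assertions above imply; a minimal usage sketch:

```python
from flaml.automl.utils import len_labels, unique_value_first_index

labels = [1, 2, 2, 3]
n_classes = len_labels(labels)                         # 3 distinct labels
uniques, first_idx = unique_value_first_index(labels)  # array([1, 2, 3]), array([0, 1, 3])
```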
@ -0,0 +1,218 @@ (new test file)
import os
import sys
import warnings
import pytest
import sklearn.datasets as skds
from flaml import AutoML
from flaml.tune.spark.utils import check_spark

warnings.simplefilter(action="ignore")
if sys.platform == "darwin" or "nt" in os.name:
    # skip this test if the platform is not linux
    skip_spark = True
else:
    try:
        import pyspark
        from pyspark.ml.feature import VectorAssembler
        from flaml.automl.spark.utils import to_pandas_on_spark

        spark = (
            pyspark.sql.SparkSession.builder.appName("MyApp")
            .master("local[1]")
            .config(
                "spark.jars.packages",
                f"com.microsoft.azure:synapseml_2.12:0.10.2,org.apache.hadoop:hadoop-azure:{pyspark.__version__},com.microsoft.azure:azure-storage:8.6.6",
            )
            .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
            .config("spark.sql.debug.maxToStringFields", "100")
            .config("spark.driver.extraJavaOptions", "-Xss1m")
            .config("spark.executor.extraJavaOptions", "-Xss1m")
            .getOrCreate()
        )
        # spark.sparkContext.setLogLevel("ERROR")
        spark_available, _ = check_spark()
        skip_spark = not spark_available
    except ImportError:
        skip_spark = True


pytestmark = pytest.mark.skipif(
    skip_spark, reason="Spark is not installed. Skip all spark tests."
)


def _test_spark_synapseml_lightgbm(spark=None, task="classification"):
    if task == "classification":
        metric = "accuracy"
        X_train, y_train = skds.load_iris(return_X_y=True, as_frame=True)
    elif task == "regression":
        metric = "r2"
        X_train, y_train = skds.load_diabetes(return_X_y=True, as_frame=True)
    elif task == "rank":
        metric = "ndcg@5"
        sdf = spark.read.format("parquet").load(
            "wasbs://publicwasb@mmlspark.blob.core.windows.net/lightGBMRanker_test.parquet"
        )
        df = to_pandas_on_spark(sdf)
        X_train = df.drop(["labels"], axis=1)
        y_train = df["labels"]

    automl_experiment = AutoML()
    automl_settings = {
        "time_budget": 10,
        "metric": metric,
        "task": task,
        "estimator_list": ["lgbm_spark"],
        "log_training_metric": True,
        "log_file_name": "test_spark_synapseml.log",
        "model_history": True,
        "verbose": 5,
    }

    y_train.name = "label"
    X_train = to_pandas_on_spark(X_train)
    y_train = to_pandas_on_spark(y_train)

    if task == "rank":
        automl_settings["groupCol"] = "query"
        automl_settings["evalAt"] = [1, 3, 5]
        automl_settings["groups"] = X_train["query"]
        automl_settings["groups"].name = "groups"
        X_train = X_train.to_spark(index_col="index")
    else:
        columns = X_train.columns
        feature_cols = [col for col in columns if col != "label"]
        featurizer = VectorAssembler(inputCols=feature_cols, outputCol="features")
        X_train = featurizer.transform(X_train.to_spark(index_col="index"))[
            "index", "features"
        ]
    X_train = to_pandas_on_spark(X_train)

    automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
    if task == "classification":
        print(automl_experiment.classes_)
    print(automl_experiment.model)
    print(automl_experiment.config_history)
    print(automl_experiment.best_model_for_estimator("lgbm_spark"))
    print(automl_experiment.best_iteration)
    print(automl_experiment.best_estimator)
    print(automl_experiment.best_loss)
    if task != "rank":
        print(automl_experiment.score(X_train, y_train, metric=metric))
    del automl_settings["metric"]
    del automl_settings["model_history"]
    del automl_settings["log_training_metric"]
    del automl_settings["verbose"]
    del automl_settings["estimator_list"]
    automl_experiment = AutoML(task=task)
    try:
        duration = automl_experiment.retrain_from_log(
            X_train=X_train,
            y_train=y_train,
            train_full=True,
            record_id=0,
            **automl_settings,
        )
        print(duration)
        print(automl_experiment.model)
        print(automl_experiment.predict(X_train)[:5])
        print(y_train.to_numpy()[:5])
    except ValueError:
        return


def test_spark_synapseml_classification():
    _test_spark_synapseml_lightgbm(spark, "classification")


def test_spark_synapseml_regression():
    _test_spark_synapseml_lightgbm(spark, "regression")


def test_spark_synapseml_rank():
    _test_spark_synapseml_lightgbm(spark, "rank")


def test_spark_input_df():
    df = (
        spark.read.format("csv")
        .option("header", True)
        .option("inferSchema", True)
        .load(
            "wasbs://publicwasb@mmlspark.blob.core.windows.net/company_bankruptcy_prediction_data.csv"
        )
    )
    train, test = df.randomSplit([0.8, 0.2], seed=1)
    feature_cols = df.columns[1:]
    featurizer = VectorAssembler(inputCols=feature_cols, outputCol="features")
    train_data = featurizer.transform(train)["Bankrupt?", "features"]
    test_data = featurizer.transform(test)["Bankrupt?", "features"]
    automl = AutoML()
    settings = {
        "time_budget": 30,  # total running time in seconds
        "metric": "roc_auc",
        "estimator_list": [
            "lgbm_spark"
        ],  # list of ML learners; we tune lightgbm in this example
        "task": "classification",  # task type
        "log_file_name": "flaml_experiment.log",  # flaml log file
        "seed": 7654321,  # random seed
    }
    df = to_pandas_on_spark(to_pandas_on_spark(train_data).to_spark(index_col="index"))

    automl.fit(
        dataframe=df,
        label="Bankrupt?",
        labelCol="Bankrupt?",
        isUnbalance=True,
        **settings,
    )

    try:
        model = automl.model.estimator
        predictions = model.transform(test_data)

        from synapse.ml.train import ComputeModelStatistics

        metrics = ComputeModelStatistics(
            evaluationMetric="classification",
            labelCol="Bankrupt?",
            scoredLabelsCol="prediction",
        ).transform(predictions)
        metrics.show()
    except AttributeError:
        print("No fitted model because of too short training time.")

    # test invalid params
    settings = {
        "time_budget": 10,  # total running time in seconds
        "metric": "roc_auc",
        "estimator_list": [
            "lgbm"
        ],  # list of ML learners; we tune lightgbm in this example
        "task": "classification",  # task type
    }
    with pytest.raises(ValueError) as excinfo:
        automl.fit(
            dataframe=df,
            label="Bankrupt?",
            labelCol="Bankrupt?",
            isUnbalance=True,
            **settings,
        )
    assert "No estimator is left." in str(excinfo.value)


if __name__ == "__main__":
    test_spark_synapseml_classification()
    test_spark_synapseml_regression()
    test_spark_synapseml_rank()
    test_spark_input_df()

    # import cProfile
    # import pstats
    # from pstats import SortKey

    # cProfile.run("test_spark_input_df()", "test_spark_input_df.profile")
    # p = pstats.Stats("test_spark_input_df.profile")
    # p.strip_dirs().sort_stats(SortKey.CUMULATIVE).print_stats("utils.py")
@@ -2,25 +2,27 @@ import os
import time

import numpy as np
import pytest
from sklearn.datasets import load_iris

from flaml import AutoML

try:
    from test.spark.custom_mylearner import *
except ImportError:
    from custom_mylearner import *

try:
    import pyspark
    from flaml.tune.spark.utils import check_spark
    from flaml.tune.spark.mylearner import lazy_metric

    os.environ["FLAML_MAX_CONCURRENT"] = "10"
    spark = pyspark.sql.SparkSession.builder.appName("App4OvertimeTest").getOrCreate()
    spark_available, _ = check_spark()
    skip_spark = not spark_available
except ImportError:
    skip_spark = True

pytestmark = pytest.mark.skipif(
    skip_spark, reason="Spark is not installed. Skip all spark tests."
@@ -48,6 +48,7 @@ def test_tune_spark():
        time_budget_s=5,
        use_spark=True,
        verbose=3,
        n_concurrent_trials=4,
    )

    # print("Best hyperparameters found were: ", analysis.best_config)
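The added n_concurrent_trials argument caps how many trials run at once when use_spark=True. For reference, a minimal self-contained sketch of a flaml.tune call using it; the objective and search space are placeholders, and running it assumes pyspark and joblib-spark are installed.

from flaml import tune


def train_fn(config):
    # Placeholder objective: distance from 3, returned as the metric to minimize.
    return {"loss": (config["x"] - 3) ** 2}


analysis = tune.run(
    train_fn,
    config={"x": tune.uniform(0, 10)},
    metric="loss",
    mode="min",
    time_budget_s=5,
    use_spark=True,  # distribute trials over the Spark cluster
    n_concurrent_trials=4,  # run at most 4 trials at the same time
)
print(analysis.best_config)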
@@ -1,16 +1,32 @@
import numpy as np
import pandas as pd
from functools import partial
from timeit import timeit
import pytest
import os

try:
    os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
    from pyspark.sql import SparkSession
    import pyspark
    import pyspark.pandas as ps
    from flaml.tune.spark.utils import (
        with_parameters,
        check_spark,
        get_n_cpus,
        get_broadcast_data,
    )
    from flaml.automl.spark.utils import (
        to_pandas_on_spark,
        train_test_split_pyspark,
        unique_pandas_on_spark,
        len_labels,
        unique_value_first_index,
        iloc_pandas_on_spark,
    )
    from flaml.automl.spark.metrics import spark_metric_loss_score
    from flaml.automl.ml import sklearn_metric_loss_score
    from pyspark.ml.linalg import Vectors

    spark_available, _ = check_spark()
    skip_spark = not spark_available
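These imports set PYARROW_IGNORE_TIMEZONE=1 before pyspark.pandas is imported, which the pandas-on-Spark API expects (it warns, or errors in some environments, otherwise). A minimal standalone sketch of that setup, separate from the test suite:

import os

os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"  # must be set before importing pyspark.pandas

import pyspark.pandas as ps
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
psdf = ps.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})  # pandas-like API backed by Spark
print(psdf.shape)  # (3, 2)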
@@ -94,8 +110,317 @@ def test_get_broadcast_data():
    assert get_broadcast_data(bc_data) == data


def test_to_pandas_on_spark(capsys):
    pdf = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
    psdf = to_pandas_on_spark(pdf)
    print(psdf)
    captured = capsys.readouterr()
    assert captured.out == "   a  b\n0  1  4\n1  2  5\n2  3  6\n"
    assert isinstance(psdf, ps.DataFrame)

    spark = SparkSession.builder.getOrCreate()
    sdf = spark.createDataFrame(pdf)
    psdf = to_pandas_on_spark(sdf)
    print(psdf)
    captured = capsys.readouterr()
    assert captured.out == "   a  b\n0  1  4\n1  2  5\n2  3  6\n"
    assert isinstance(psdf, ps.DataFrame)

    pds = pd.Series([1, 2, 3])
    pss = to_pandas_on_spark(pds)
    print(pss)
    captured = capsys.readouterr()
    assert captured.out == "0    1\n1    2\n2    3\ndtype: int64\n"
    assert isinstance(pss, ps.Series)


def test_train_test_split_pyspark():
    pdf = pd.DataFrame({"x": [1, 2, 3, 4], "y": [0, 1, 1, 0]})
    spark = SparkSession.builder.getOrCreate()
    sdf = spark.createDataFrame(pdf).repartition(1)
    psdf = to_pandas_on_spark(sdf).spark.repartition(1)
    train_sdf, test_sdf = train_test_split_pyspark(
        sdf, test_fraction=0.5, to_pandas_spark=False, seed=1
    )
    train_psdf, test_psdf = train_test_split_pyspark(
        psdf, test_fraction=0.5, stratify_column="y", seed=1
    )
    assert isinstance(train_sdf, pyspark.sql.dataframe.DataFrame)
    assert isinstance(test_sdf, pyspark.sql.dataframe.DataFrame)
    assert isinstance(train_psdf, ps.DataFrame)
    assert isinstance(test_psdf, ps.DataFrame)
    assert train_sdf.count() == 2
    assert train_psdf.shape[0] == 2
    print(train_sdf.toPandas())
    print(test_sdf.toPandas())
    print(train_psdf.to_pandas())
    print(test_psdf.to_pandas())


def test_unique_pandas_on_spark():
    pdf = pd.DataFrame({"x": [1, 2, 2, 3], "y": [0, 1, 1, 0]})
    spark = SparkSession.builder.getOrCreate()
    sdf = spark.createDataFrame(pdf)
    psdf = to_pandas_on_spark(sdf)
    label_set, counts = unique_pandas_on_spark(psdf)
    assert np.array_equal(label_set, np.array([2, 1, 3]))
    assert np.array_equal(counts, np.array([2, 1, 1]))


def test_len_labels():
    y1 = np.array([1, 2, 5, 4, 5])
    y2 = ps.Series([1, 2, 5, 4, 5])
    assert len_labels(y1) == 4
    ll, la = len_labels(y2, return_labels=True)
    assert ll == 4
    assert set(la.to_numpy()) == set([1, 2, 5, 4])


def test_unique_value_first_index():
    y1 = np.array([1, 2, 5, 4, 5])
    y2 = ps.Series([1, 2, 5, 4, 5])
    l1, f1 = unique_value_first_index(y1)
    l2, f2 = unique_value_first_index(y2)
    assert np.array_equal(l1, np.array([1, 2, 4, 5]))
    assert np.array_equal(f1, np.array([0, 1, 3, 2]))
    assert np.array_equal(l2, np.array([1, 2, 5, 4]))
    assert np.array_equal(f2, np.array([0, 1, 2, 3]))


def test_n_current_trials():
    spark = SparkSession.builder.getOrCreate()
    sc = spark._jsc.sc()
    num_executors = (
        len([executor.host() for executor in sc.statusTracker().getExecutorInfos()]) - 1
    )

    def get_n_current_trials(n_concurrent_trials=0, num_executors=num_executors):
        try:
            FLAML_MAX_CONCURRENT = int(os.getenv("FLAML_MAX_CONCURRENT", 0))
            num_executors = max(num_executors, FLAML_MAX_CONCURRENT, 1)
        except ValueError:
            FLAML_MAX_CONCURRENT = 0
        max_spark_parallelism = (
            min(spark.sparkContext.defaultParallelism, FLAML_MAX_CONCURRENT)
            if FLAML_MAX_CONCURRENT > 0
            else spark.sparkContext.defaultParallelism
        )
        max_concurrent = max(1, max_spark_parallelism)
        n_concurrent_trials = min(
            n_concurrent_trials if n_concurrent_trials > 0 else num_executors,
            max_concurrent,
        )
        print("n_concurrent_trials:", n_concurrent_trials)
        return n_concurrent_trials

    os.environ["FLAML_MAX_CONCURRENT"] = "invalid"
    assert get_n_current_trials() == num_executors
    os.environ["FLAML_MAX_CONCURRENT"] = "0"
    assert get_n_current_trials() == max(num_executors, 1)
    os.environ["FLAML_MAX_CONCURRENT"] = "4"
    tmp_max = min(4, spark.sparkContext.defaultParallelism)
    assert get_n_current_trials() == tmp_max
    os.environ["FLAML_MAX_CONCURRENT"] = "9999999"
    assert get_n_current_trials() == spark.sparkContext.defaultParallelism
    os.environ["FLAML_MAX_CONCURRENT"] = "100"
    tmp_max = min(100, spark.sparkContext.defaultParallelism)
    assert get_n_current_trials(1) == 1
    assert get_n_current_trials(2) == min(2, tmp_max)
    assert get_n_current_trials(50) == min(50, tmp_max)
    assert get_n_current_trials(200) == min(200, tmp_max)


def test_iloc_pandas_on_spark():
    psdf = ps.DataFrame({"x": [1, 2, 2, 3], "y": [0, 1, 1, 0]}, index=[0, 1, 2, 3])
    psds = ps.Series([1, 2, 2, 3], index=[0, 1, 2, 3])
    assert iloc_pandas_on_spark(psdf, 0).tolist() == [1, 0]
    d1 = iloc_pandas_on_spark(psdf, slice(1, 3)).to_pandas()
    d2 = pd.DataFrame({"x": [2, 2], "y": [1, 1]}, index=[1, 2])
    assert d1.equals(d2)
    d1 = iloc_pandas_on_spark(psdf, [1, 3]).to_pandas()
    d2 = pd.DataFrame({"x": [2, 3], "y": [1, 0]}, index=[0, 1])
    assert d1.equals(d2)
    assert iloc_pandas_on_spark(psds, 0) == 1
    assert iloc_pandas_on_spark(psds, slice(1, 3)).tolist() == [2, 2]
    assert iloc_pandas_on_spark(psds, [0, 3]).tolist() == [1, 3]


def test_spark_metric_loss_score():
    spark = SparkSession.builder.getOrCreate()
    scoreAndLabels = map(
        lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]),
        [
            (0.1, 0.0),
            (0.1, 1.0),
            (0.4, 0.0),
            (0.6, 0.0),
            (0.6, 1.0),
            (0.6, 1.0),
            (0.8, 1.0),
        ],
    )
    dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
    dataset = to_pandas_on_spark(dataset)
    # test pr_auc
    metric = spark_metric_loss_score(
        "pr_auc",
        dataset["raw"],
        dataset["label"],
    )
    print("pr_auc: ", metric)
    assert str(metric)[:5] == "0.166"
    # test roc_auc
    metric = spark_metric_loss_score(
        "roc_auc",
        dataset["raw"],
        dataset["label"],
    )
    print("roc_auc: ", metric)
    assert str(metric)[:5] == "0.291"

    scoreAndLabels = [
        (-28.98343821, -27.0),
        (20.21491975, 21.5),
        (-25.98418959, -22.0),
        (30.69731842, 33.0),
        (74.69283752, 71.0),
    ]
    dataset = spark.createDataFrame(scoreAndLabels, ["raw", "label"])
    dataset = to_pandas_on_spark(dataset)
    # test rmse
    metric = spark_metric_loss_score(
        "rmse",
        dataset["raw"],
        dataset["label"],
    )
    print("rmse: ", metric)
    assert str(metric)[:5] == "2.842"
    # test mae
    metric = spark_metric_loss_score(
        "mae",
        dataset["raw"],
        dataset["label"],
    )
    print("mae: ", metric)
    assert str(metric)[:5] == "2.649"
    # test r2
    metric = spark_metric_loss_score(
        "r2",
        dataset["raw"],
        dataset["label"],
    )
    print("r2: ", metric)
    assert str(metric)[:5] == "0.006"
    # test mse
    metric = spark_metric_loss_score(
        "mse",
        dataset["raw"],
        dataset["label"],
    )
    print("mse: ", metric)
    assert str(metric)[:5] == "8.079"
    # test var
    metric = spark_metric_loss_score(
        "var",
        dataset["raw"],
        dataset["label"],
    )
    print("var: ", metric)
    assert str(metric)[:5] == "-1489"

    predictionAndLabelsWithProbabilities = [
        (1.0, 1.0, 1.0, [0.1, 0.8, 0.1]),
        (0.0, 2.0, 1.0, [0.9, 0.05, 0.05]),
        (0.0, 0.0, 1.0, [0.8, 0.2, 0.0]),
        (1.0, 1.0, 1.0, [0.3, 0.65, 0.05]),
    ]
    dataset = spark.createDataFrame(
        predictionAndLabelsWithProbabilities,
        ["prediction", "label", "weight", "probability"],
    )
    dataset = to_pandas_on_spark(dataset)
    # test logloss
    metric = spark_metric_loss_score(
        "log_loss",
        dataset["probability"],
        dataset["label"],
    )
    print("log_loss: ", metric)
    assert str(metric)[:5] == "0.968"
    # test accuracy
    metric = spark_metric_loss_score(
        "accuracy",
        dataset["prediction"],
        dataset["label"],
    )
    print("accuracy: ", metric)
    assert str(metric)[:5] == "0.25"
    # test f1
    metric = spark_metric_loss_score(
        "f1",
        dataset["prediction"],
        dataset["label"],
    )
    print("f1: ", metric)
    assert str(metric)[:5] == "0.333"

    scoreAndLabels = [
        ([0.0, 1.0], [0.0, 2.0]),
        ([0.0, 2.0], [0.0, 1.0]),
        ([], [0.0]),
        ([2.0], [2.0]),
        ([2.0, 0.0], [2.0, 0.0]),
        ([0.0, 1.0, 2.0], [0.0, 1.0]),
        ([1.0], [1.0, 2.0]),
    ]
    dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
    dataset = to_pandas_on_spark(dataset)
    # test micro_f1
    metric = spark_metric_loss_score(
        "micro_f1",
        dataset["prediction"],
        dataset["label"],
    )
    print("micro_f1: ", metric)
    assert str(metric)[:5] == "0.304"
    # test macro_f1
    metric = spark_metric_loss_score(
        "macro_f1",
        dataset["prediction"],
        dataset["label"],
    )
    print("macro_f1: ", metric)
    assert str(metric)[:5] == "0.111"

    scoreAndLabels = [
        (
            [1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0],
            [1.0, 2.0, 3.0, 4.0, 5.0],
        ),
        ([4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0], [1.0, 2.0, 3.0]),
        ([1.0, 2.0, 3.0, 4.0, 5.0], []),
    ]
    dataset = spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
    dataset = to_pandas_on_spark(dataset)
    # test ap
    metric = spark_metric_loss_score(
        "ap",
        dataset["prediction"],
        dataset["label"],
    )
    print("ap: ", metric)
    assert str(metric)[:5] == "0.644"
    # test ndcg
    # ndcg is tested in synapseML rank tests, so we don't need to test it here


if __name__ == "__main__":
    # test_with_parameters_spark()
    # test_get_n_cpus_spark()
    # test_broadcast_code()
    # test_get_broadcast_data()
    # test_train_test_split_pyspark()
    # test_n_current_trials()
    # test_len_labels()
    # test_iloc_pandas_on_spark()
    test_spark_metric_loss_score()
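As a closing reference, spark_metric_loss_score is only checked above against fixed expected values; a minimal illustrative sketch (not FLAML's actual implementation) of how such losses can be derived from pyspark.ml evaluators, assuming a Spark DataFrame with "raw" and "label" columns as in the test data:

from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator


def roc_auc_loss(sdf):
    # areaUnderROC is "higher is better", so report 1 - AUC as a loss.
    evaluator = BinaryClassificationEvaluator(
        rawPredictionCol="raw", labelCol="label", metricName="areaUnderROC"
    )
    return 1 - evaluator.evaluate(sdf)


def rmse_loss(sdf):
    # rmse is already a loss, so return it directly.
    evaluator = RegressionEvaluator(predictionCol="raw", labelCol="label", metricName="rmse")
    return evaluator.evaluate(sdf)

The expected roc_auc value of 0.291 in the test above is consistent with this 1 - AUC convention.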