Add support for using Spark as the backend of parallel training (#846)

* Added spark support for parallel training.

* Added tests and fixed a bug

* Added more tests and updated docs

* Updated setup.py and docs

* Added customize_learner and tests

* Update spark tests and setup.py

* Update docs and verbose

* Update logging, fix issue in cloud notebook

* Update github workflow for spark tests

* Update github workflow

* Remove hack for handling _choice_

* Allow for failures

* Fix tests, update docs

* Update setup.py

* Update Dockerfile for Spark

* Update tests, remove some warnings

* Add test for notebooks, update utils

* Add performance test for Spark

* Fix lru_cache maxsize

* Fix test failures on some platforms

* Fix coverage report failure

* Resolve PR comments

* Resolve PR comments (2nd round)

* Resolve PR comments (3rd round)

* Fix lint and rename test class

* Resolve PR comments (4th round)

* Refactor customize_learner to broadcast_code
Li Jiang 2022-12-24 00:18:49 +08:00 committed by GitHub
parent 4140fc9022
commit da2cd7ca89
26 changed files with 1820 additions and 76 deletions


@ -37,6 +37,15 @@ jobs:
export CFLAGS="$CFLAGS -I/usr/local/opt/libomp/include"
export CXXFLAGS="$CXXFLAGS -I/usr/local/opt/libomp/include"
export LDFLAGS="$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp"
- name: On Linux, install Spark stand-alone cluster and PySpark
if: matrix.os == 'ubuntu-latest'
run: |
sudo apt-get update && sudo apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends ca-certificates-java ca-certificates openjdk-17-jdk-headless && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*
wget --progress=dot:giga "https://www.apache.org/dyn/closer.lua/spark/spark-3.3.0/spark-3.3.0-bin-hadoop2.tgz?action=download" -O - | tar -xzC /tmp; archive=$(basename "spark-3.3.0/spark-3.3.0-bin-hadoop2.tgz") bash -c "sudo mv -v /tmp/\${archive/%.tgz/} /spark"
pip install --no-cache-dir pyspark>=3.0
export SPARK_HOME=/spark
export PYTHONPATH=/spark/python/lib/py4j-0.10.9.5-src.zip:/spark/python
export PATH=$PATH:$SPARK_HOME/bin
- name: Install packages and dependencies
run: |
python -m pip install --upgrade pip wheel
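As a quick way to sanity-check this CI step, the snippet below starts a small local session with the freshly installed PySpark. It is an illustrative sketch and not part of the workflow itself.

```python
# Minimal local check (illustrative, not part of the workflow) that the Spark
# standalone install and PySpark configured above are usable.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder.master("local[2]")
    .appName("flaml-spark-install-check")
    .getOrCreate()
)
print(spark.version)  # expect 3.3.x to match the downloaded tarball
spark.stop()
```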


@ -3,6 +3,16 @@ FROM python:3.7
RUN apt-get update && apt-get -y update
RUN apt-get install -y sudo git npm
# Install Spark
RUN sudo apt-get update && sudo apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \
ca-certificates-java ca-certificates openjdk-17-jdk-headless \
wget \
&& sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*
RUN wget --progress=dot:giga "https://www.apache.org/dyn/closer.lua/spark/spark-3.3.0/spark-3.3.0-bin-hadoop2.tgz?action=download" -O - | tar -xzC /tmp; archive=$(basename "spark-3.3.0/spark-3.3.0-bin-hadoop2.tgz") bash -c "sudo mv -v /tmp/\${archive/%.tgz/} /spark"
ENV SPARK_HOME=/spark \
PYTHONPATH=/spark/python/lib/py4j-0.10.9.5-src.zip:/spark/python
ENV PATH="${PATH}:${SPARK_HOME}/bin"
# Setup user to not run as root
RUN adduser --disabled-password --gecos '' flaml-dev
RUN adduser flaml-dev sudo


@ -4,6 +4,7 @@
# * project root for license information.
import time
import os
import sys
from typing import Callable, Optional, List, Union, Any
import inspect
from functools import partial
@ -54,17 +55,28 @@ from flaml import tune
from flaml.automl.training_log import training_log_reader, training_log_writer
from flaml.default import suggest_learner
from flaml.version import __version__ as flaml_version
from flaml.tune.spark.utils import check_spark, get_broadcast_data
logger = logging.getLogger(__name__)
logger_formatter = logging.Formatter(
"[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s", "%m-%d %H:%M:%S"
)
logger.propagate = False
try:
import mlflow
except ImportError:
mlflow = None
try:
from ray import __version__ as ray_version
assert ray_version >= "1.10.0"
ray_available = True
except (ImportError, AssertionError):
ray_available = False
class SearchState:
@property
@ -331,7 +343,7 @@ class AutoMLState:
return sampled_X_train, sampled_y_train, sampled_weight, groups
@staticmethod
def _compute_with_config_base(config_w_resource, state, estimator):
def _compute_with_config_base(config_w_resource, state, estimator, is_report=True):
if "FLAML_sample_size" in config_w_resource:
sample_size = int(config_w_resource["FLAML_sample_size"])
else:
@ -407,7 +419,8 @@ class AutoMLState:
}
if sampled_weight is not None:
this_estimator_kwargs["sample_weight"] = weight
tune.report(**result)
if is_report is True:
tune.report(**result)
return result
@classmethod
@ -648,7 +661,10 @@ class AutoML(BaseEstimator):
n_concurrent_trials: [Experimental] int, default=1 | The number of
concurrent trials. When n_concurrent_trials > 1, flaml performs
[parallel tuning](../../Use-Cases/Task-Oriented-AutoML#parallel-tuning)
and installation of ray is required: `pip install flaml[ray]`.
and installation of ray or spark is required: `pip install flaml[ray]`
or `pip install flaml[spark]`. Please check
[here](https://spark.apache.org/docs/latest/api/python/getting_started/install.html)
for more details about installing Spark.
keep_search_state: boolean, default=False | Whether to keep data needed
for model search after fit(). By default the state is deleted for
space saving.
@ -668,6 +684,15 @@ class AutoML(BaseEstimator):
datasets, but will incur more overhead in time.
If dict: the dict contains the keywords arguments to be passed to
[ray.tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html).
use_spark: boolean, default=False | Whether to use spark to run the training
in parallel spark jobs. This can be used to accelerate training on large models
and large datasets, but will incur more overhead in time and thus slow down
training in some cases. GPU training is not supported yet when use_spark is True.
For Spark clusters, by default, we will launch one trial per executor. However,
sometimes we want to launch more trials than the number of executors (e.g., local mode).
In this case, we can set the environment variable `FLAML_MAX_CONCURRENT` to override
the detected `num_executors`. The final number of concurrent trials will be the minimum
of `n_concurrent_trials` and `num_executors`.
free_mem_ratio: float between 0 and 1, default=0. The free memory ratio to keep during training.
metric_constraints: list, default=[] | The list of metric constraints.
Each element in this list is a 3-tuple, which shall be expressed
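For reference, here is a minimal sketch of the Spark-backed parallel tuning documented above. It assumes a working `pip install flaml[spark]` environment; the dataset and budgets are illustrative placeholders.

```python
# Sketch of Spark-backed parallel AutoML, assuming pyspark>=3.0 and joblibspark
# are installed (pip install flaml[spark]); data and budgets are illustrative.
from sklearn.datasets import load_iris
from flaml import AutoML

X_train, y_train = load_iris(return_X_y=True)
automl = AutoML()
automl.fit(
    X_train=X_train,
    y_train=y_train,
    task="classification",
    time_budget=30,
    n_concurrent_trials=2,  # >1 enables parallel tuning
    use_spark=True,         # use Spark instead of Ray as the parallel backend
    n_jobs=1,
)
print(automl.best_estimator)
```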
@ -759,6 +784,9 @@ class AutoML(BaseEstimator):
settings["append_log"] = settings.get("append_log", False)
settings["min_sample_size"] = settings.get("min_sample_size", MIN_SAMPLE_TRAIN)
settings["use_ray"] = settings.get("use_ray", False)
settings["use_spark"] = settings.get("use_spark", False)
if settings["use_ray"] is not False and settings["use_spark"] is not False:
raise ValueError("use_ray and use_spark cannot be both True.")
settings["free_mem_ratio"] = settings.get("free_mem_ratio", 0)
settings["metric_constraints"] = settings.get("metric_constraints", [])
settings["cv_score_agg_func"] = settings.get("cv_score_agg_func", None)
@ -2081,8 +2109,10 @@ class AutoML(BaseEstimator):
states = self._search_states
mem_res = self._mem_thres
def train(config: dict, state):
def train(config: dict, state, is_report=True):
# handle spark broadcast variables
state = get_broadcast_data(state)
is_report = get_broadcast_data(is_report)
sample_size = config.get("FLAML_sample_size")
config = config.get("ml", config).copy()
if sample_size:
@ -2093,7 +2123,7 @@ class AutoML(BaseEstimator):
del config["learner"]
config.pop("_choice_", None)
result = AutoMLState._compute_with_config_base(
config, state=state, estimator=estimator
config, state=state, estimator=estimator, is_report=is_report
)
else:
# If search algorithm is not in flaml, it does not handle the config constraint, should also tune.report before return
@ -2104,7 +2134,8 @@ class AutoML(BaseEstimator):
"val_loss": np.inf,
"trained_estimator": None,
}
tune.report(**result)
if is_report is True:
tune.report(**result)
return result
if self._use_ray is not False:
@ -2114,6 +2145,10 @@ class AutoML(BaseEstimator):
train,
state=self._state,
)
elif self._use_spark:
from flaml.tune.spark.utils import with_parameters
return with_parameters(train, state=self._state, is_report=False)
else:
return partial(
train,
@ -2174,6 +2209,7 @@ class AutoML(BaseEstimator):
auto_augment=None,
min_sample_size=None,
use_ray=None,
use_spark=None,
free_mem_ratio=0,
metric_constraints=None,
custom_hp=None,
@ -2347,7 +2383,10 @@ class AutoML(BaseEstimator):
n_concurrent_trials: [Experimental] int, default=1 | The number of
concurrent trials. When n_concurrent_trials > 1, flaml performs
[parallel tuning](../../Use-Cases/Task-Oriented-AutoML#parallel-tuning)
and installation of ray is required: `pip install flaml[ray]`.
and installation of ray or spark is required: `pip install flaml[ray]`
or `pip install flaml[spark]`. Please check
[here](https://spark.apache.org/docs/latest/api/python/getting_started/install.html)
for more details about installing Spark.
keep_search_state: boolean, default=False | Whether to keep data needed
for model search after fit(). By default the state is deleted for
space saving.
@ -2367,6 +2406,10 @@ class AutoML(BaseEstimator):
datasets, but will incur more overhead in time.
If dict: the dict contains the keywords arguments to be passed to
[ray.tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html).
use_spark: boolean, default=False | Whether to use spark to run the training
in parallel spark jobs. This can be used to accelerate training on large models
and large datasets, but will incur more overhead in time and thus slow down
training in some cases.
free_mem_ratio: float between 0 and 1, default=0. The free memory ratio to keep during training.
metric_constraints: list, default=[] | The list of metric constraints.
Each element in this list is a 3-tuple, which shall be expressed
@ -2560,12 +2603,49 @@ class AutoML(BaseEstimator):
)
min_sample_size = min_sample_size or self._settings.get("min_sample_size")
use_ray = self._settings.get("use_ray") if use_ray is None else use_ray
use_spark = self._settings.get("use_spark") if use_spark is None else use_spark
spark_available, spark_error_msg = check_spark()
if use_spark and use_ray is not False:
raise ValueError("use_spark and use_ray cannot be both True.")
elif use_spark and not spark_available:
raise spark_error_msg
old_level = logger.getEffectiveLevel()
self.verbose = verbose
logger.setLevel(50 - verbose * 10)
if not logger.handlers:
# Add the console handler.
_ch = logging.StreamHandler(stream=sys.stdout)
_ch.setFormatter(logger_formatter)
logger.addHandler(_ch)
if not use_ray and not use_spark and n_concurrent_trials > 1:
if ray_available:
logger.warning(
"n_concurrent_trials > 1 is only supported when using Ray or Spark. "
"Ray installed, setting use_ray to True. If you want to use Spark, set use_spark to True."
)
use_ray = True
elif spark_available:
logger.warning(
"n_concurrent_trials > 1 is only supported when using Ray or Spark. "
"Spark installed, setting use_spark to True. If you want to use Ray, set use_ray to True."
)
use_spark = True
else:
logger.warning(
"n_concurrent_trials > 1 is only supported when using Ray or Spark. "
"Neither Ray nor Spark installed, setting n_concurrent_trials to 1."
)
n_concurrent_trials = 1
self._state.n_jobs = n_jobs
self._n_concurrent_trials = n_concurrent_trials
self._early_stop = early_stop
self._use_ray = use_ray or n_concurrent_trials > 1
self._use_spark = use_spark
self._use_ray = use_ray
# use the following condition if we have an estimation of average_trial_time and average_trial_overhead
# self._use_ray = use_ray or n_concurrent_trials > ( average_trail_time + average_trial_overhead) / (average_trial_time)
# self._use_ray = use_ray or n_concurrent_trials > ( average_trial_time + average_trial_overhead) / (average_trial_time)
if self._use_ray is not False:
import ray
@ -2594,6 +2674,11 @@ class AutoML(BaseEstimator):
X_train = ray.get(X_train)
elif isinstance(dataframe, ray.ObjectRef):
dataframe = ray.get(dataframe)
else:
# TODO: Integrate with Spark
self._state.resources_per_trial = (
{"cpu": n_jobs} if n_jobs > 0 else {"cpu": 1}
)
self._state.free_mem_ratio = (
self._settings.get("free_mem_ratio")
if free_mem_ratio is None
@ -2624,14 +2709,6 @@ class AutoML(BaseEstimator):
self._random = np.random.RandomState(RANDOM_SEED)
self._seed = seed if seed is not None else 20
self._learner_selector = learner_selector
old_level = logger.getEffectiveLevel()
self.verbose = verbose
logger.setLevel(50 - verbose * 10)
if not logger.handlers:
# Add the console handler.
_ch = logging.StreamHandler()
_ch.setFormatter(logger_formatter)
logger.addHandler(_ch)
logger.info(f"task = {task}")
self._decide_split_type(split_type)
logger.info(f"Data split method: {self._split_type}")
@ -2927,7 +3004,7 @@ class AutoML(BaseEstimator):
else (
"bs"
if n_concurrent_trials > 1
or self._use_ray is not False
or (self._use_ray is not False or self._use_spark)
and len(estimator_list) > 1
else "cfo"
)
@ -2975,20 +3052,24 @@ class AutoML(BaseEstimator):
logger.setLevel(old_level)
def _search_parallel(self):
try:
from ray import __version__ as ray_version
if self._use_ray is not False:
try:
from ray import __version__ as ray_version
assert ray_version >= "1.10.0"
if ray_version.startswith("1."):
from ray.tune.suggest import ConcurrencyLimiter
else:
from ray.tune.search import ConcurrencyLimiter
import ray
except (ImportError, AssertionError):
raise ImportError(
"use_ray=True requires installation of ray. "
"Please run pip install flaml[ray]"
)
else:
from flaml.tune.searcher.suggestion import ConcurrencyLimiter
assert ray_version >= "1.10.0"
if ray_version.startswith("1."):
from ray.tune.suggest import ConcurrencyLimiter
else:
from ray.tune.search import ConcurrencyLimiter
import ray
except (ImportError, AssertionError):
raise ImportError(
"n_concurrent_trial>1 or use_ray=True requires installation of ray. "
"Please run pip install flaml[ray]"
)
if self._hpo_method in ("cfo", "grid"):
from flaml import CFO as SearchAlgo
elif "bs" == self._hpo_method:
@ -2996,15 +3077,20 @@ class AutoML(BaseEstimator):
elif "random" == self._hpo_method:
from flaml import RandomSearch as SearchAlgo
elif "optuna" == self._hpo_method:
try:
from ray import __version__ as ray_version
if self._use_ray is not False:
try:
from ray import __version__ as ray_version
assert ray_version >= "1.10.0"
if ray_version.startswith("1."):
from ray.tune.suggest.optuna import OptunaSearch as SearchAlgo
else:
from ray.tune.search.optuna import OptunaSearch as SearchAlgo
except (ImportError, AssertionError):
assert ray_version >= "1.10.0"
if ray_version.startswith("1."):
from ray.tune.suggest.optuna import OptunaSearch as SearchAlgo
else:
from ray.tune.search.optuna import OptunaSearch as SearchAlgo
except (ImportError, AssertionError):
from flaml.tune.searcher.suggestion import (
OptunaSearch as SearchAlgo,
)
else:
from flaml.tune.searcher.suggestion import OptunaSearch as SearchAlgo
else:
raise NotImplementedError(
@ -3048,7 +3134,7 @@ class AutoML(BaseEstimator):
allow_empty_config=True,
)
else:
# if self._hpo_method is bo, sometimes the search space and the initial config dimension do not match
# if self._hpo_method is optuna, sometimes the search space and the initial config dimension do not match
# need to remove the extra keys from the search space to be consistent with the initial config
converted_space = SearchAlgo.convert_search_space(space)
@ -3070,21 +3156,40 @@ class AutoML(BaseEstimator):
search_alg = ConcurrencyLimiter(search_alg, self._n_concurrent_trials)
resources_per_trial = self._state.resources_per_trial
analysis = ray.tune.run(
self.trainable,
search_alg=search_alg,
config=space,
metric="val_loss",
mode="min",
resources_per_trial=resources_per_trial,
time_budget_s=time_budget_s,
num_samples=self._max_iter,
verbose=max(self.verbose - 2, 0),
raise_on_failed_trial=False,
keep_checkpoints_num=1,
checkpoint_score_attr="min-val_loss",
**self._use_ray if isinstance(self._use_ray, dict) else {},
)
if self._use_spark:
# use spark as parallel backend
analysis = tune.run(
self.trainable,
search_alg=search_alg,
config=space,
metric="val_loss",
mode="min",
time_budget_s=time_budget_s,
num_samples=self._max_iter,
verbose=max(self.verbose - 2, 0),
use_ray=False,
use_spark=True,
# raise_on_failed_trial=False,
# keep_checkpoints_num=1,
# checkpoint_score_attr="min-val_loss",
)
else:
# use ray as parallel backend
analysis = ray.tune.run(
self.trainable,
search_alg=search_alg,
config=space,
metric="val_loss",
mode="min",
resources_per_trial=resources_per_trial,
time_budget_s=time_budget_s,
num_samples=self._max_iter,
verbose=max(self.verbose - 2, 0),
raise_on_failed_trial=False,
keep_checkpoints_num=1,
checkpoint_score_attr="min-val_loss",
**self._use_ray if isinstance(self._use_ray, dict) else {},
)
# logger.info([trial.last_result for trial in analysis.trials])
trials = sorted(
(
@ -3288,7 +3393,7 @@ class AutoML(BaseEstimator):
num_samples=self._max_iter,
)
else:
# if self._hpo_method is bo, sometimes the search space and the initial config dimension do not match
# if self._hpo_method is optuna, sometimes the search space and the initial config dimension do not match
# need to remove the extra keys from the search space to be consistent with the initial config
converted_space = SearchAlgo.convert_search_space(search_space)
removed_keys = set(search_space.keys()).difference(
@ -3327,6 +3432,7 @@ class AutoML(BaseEstimator):
time_budget_s=time_budget_s,
verbose=max(self.verbose - 3, 0),
use_ray=False,
use_spark=False,
)
time_used = time.time() - start_run_time
better = False
@ -3497,7 +3603,7 @@ class AutoML(BaseEstimator):
self._selected = state = self._search_states[estimator]
state.best_config_sample_size = self._state.data_size[0]
state.best_config = state.init_config[0] if state.init_config else {}
elif self._use_ray is False:
elif self._use_ray is False and self._use_spark is False:
self._search_sequential()
else:
self._search_parallel()
@ -3561,6 +3667,10 @@ class AutoML(BaseEstimator):
and ray.available_resources()["CPU"]
or os.cpu_count()
)
elif self._use_spark:
from flaml.tune.spark.utils import get_n_cpus
n_cpus = get_n_cpus()
else:
n_cpus = os.cpu_count()
ensemble_n_jobs = (


@ -0,0 +1,8 @@
from flaml.tune.spark.utils import (
check_spark,
get_n_cpus,
with_parameters,
broadcast_code,
)
__all__ = ["check_spark", "get_n_cpus", "with_parameters", "broadcast_code"]
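The example below is a small usage sketch of these re-exported helpers. It assumes a local PySpark installation; note that `check_spark` logs a warning and returns the error instead of raising.

```python
# Usage sketch of the re-exported helpers; assumes PySpark is installed locally.
from flaml.tune.spark.utils import check_spark, get_n_cpus

spark_available, spark_error = check_spark()  # cached, does not raise
if spark_available:
    print("driver cores:", get_n_cpus("driver"))
else:
    print("Spark unavailable:", spark_error)
```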

flaml/tune/spark/utils.py (new file)

@ -0,0 +1,191 @@
import os
import logging
from functools import partial, lru_cache
import textwrap
logger = logging.getLogger(__name__)
logger_formatter = logging.Formatter(
"[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s", "%m-%d %H:%M:%S"
)
try:
from pyspark.sql import SparkSession
from pyspark.util import VersionUtils
import pyspark
_have_spark = True
_spark_major_minor_version = VersionUtils.majorMinorVersion(pyspark.__version__)
except ImportError as e:
logger.debug("Could not import pyspark: %s", e)
_have_spark = False
_spark_major_minor_version = (0, 0)
@lru_cache(maxsize=2)
def check_spark():
"""Check if Spark is installed and running.
Result of the function will be cached since test once is enough. As lru_cache will not
cache exceptions, we don't raise exceptions here but only log a warning message.
Returns:
Return (True, None) if the check passes, otherwise log the exception message and
return (False, Exception(msg)). The exception can be raised by the caller.
"""
logger.warning("\ncheck Spark installation...This line should appear only once.\n")
if not _have_spark:
msg = """use_spark=True requires installation of PySpark. Please run pip install flaml[spark]
and check [here](https://spark.apache.org/docs/latest/api/python/getting_started/install.html)
for more details about installing Spark."""
logger.warning(msg)
return False, ImportError(msg)
if _spark_major_minor_version[0] < 3:
msg = "Spark version must be >= 3.0 to use flaml[spark]"
logger.warning(msg)
return False, ImportError(msg)
try:
SparkSession.builder.getOrCreate()
except RuntimeError as e:
logger.warning(f"\nSparkSession is not available: {e}\n")
return False, RuntimeError(e)
return True, None
def get_n_cpus(node="driver"):
"""Get the number of CPU cores of the given type of node.
Args:
node: string | The type of node to get the number of cores. Can be 'driver' or 'executor'.
Default is 'driver'.
Returns:
An int of the number of CPU cores.
"""
assert node in ["driver", "executor"]
try:
n_cpus = int(
SparkSession.builder.getOrCreate()
.sparkContext.getConf()
.get(f"spark.{node}.cores")
)
except (TypeError, RuntimeError):
n_cpus = os.cpu_count()
return n_cpus
def with_parameters(trainable, **kwargs):
"""Wrapper for trainables to pass arbitrary large data objects.
This wrapper function will store all passed parameters in the Spark
Broadcast variable.
Args:
trainable: Trainable to wrap.
**kwargs: parameters to store in object store.
Returns:
A new function with partial application of the given arguments
and keywords. The given arguments and keywords will be broadcasted
to all the executors.
```python
import pyspark
import flaml
from sklearn.datasets import load_iris
def train(config, data=None):
if isinstance(data, pyspark.broadcast.Broadcast):
data = data.value
print(config, data)
data = load_iris()
with_parameters_train = flaml.tune.spark.utils.with_parameters(train, data=data)
with_parameters_train(config=1)
train(config={"metric": "accuracy"})
```
"""
if not callable(trainable):
raise ValueError(
f"`with_parameters() only works with function trainables`. "
f"Got type: "
f"{type(trainable)}."
)
spark_available, spark_error_msg = check_spark()
if not spark_available:
raise spark_error_msg
spark = SparkSession.builder.getOrCreate()
bc_kwargs = dict()
for k, v in kwargs.items():
bc_kwargs[k] = spark.sparkContext.broadcast(v)
return partial(trainable, **bc_kwargs)
def broadcast_code(custom_code="", file_name="mylearner"):
"""Write customized learner/metric code contents to a file for importing.
It is necessary for using the customized learner/metric in spark backend.
The path of the learner/metric file will be returned.
Args:
custom_code: str, default="" | code contents of the custom learner/metric.
file_name: str, default="mylearner" | file name of the custom learner/metric.
Returns:
The path of the custom code file.
```python
from flaml.tune.spark.utils import broadcast_code
from flaml.automl.model import LGBMEstimator
custom_code = '''
from flaml.automl.model import LGBMEstimator
from flaml import tune
class MyLargeLGBM(LGBMEstimator):
@classmethod
def search_space(cls, **params):
return {
"n_estimators": {
"domain": tune.lograndint(lower=4, upper=32768),
"init_value": 32768,
"low_cost_init_value": 4,
},
"num_leaves": {
"domain": tune.lograndint(lower=4, upper=32768),
"init_value": 32768,
"low_cost_init_value": 4,
},
}
'''
broadcast_code(custom_code=custom_code)
from flaml.tune.spark.mylearner import MyLargeLGBM
assert isinstance(MyLargeLGBM(), LGBMEstimator)
```
"""
flaml_path = os.path.dirname(os.path.abspath(__file__))
custom_code = textwrap.dedent(custom_code)
custom_path = os.path.join(flaml_path, file_name + ".py")
with open(custom_path, "w") as f:
f.write(custom_code)
return custom_path
def get_broadcast_data(broadcast_data):
"""Get the broadcast data from the broadcast variable.
Args:
broadcast_data: pyspark.broadcast.Broadcast | the broadcast variable.
Returns:
The broadcast data.
"""
if _have_spark and isinstance(broadcast_data, pyspark.broadcast.Broadcast):
broadcast_data = broadcast_data.value
return broadcast_data
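Putting `with_parameters` and `get_broadcast_data` together, a trainable can stay backend-agnostic. The sketch below uses toy data and assumes a working Spark session (otherwise `with_parameters` raises the error returned by `check_spark`).

```python
# Sketch combining with_parameters and get_broadcast_data (toy data, assumes a
# working Spark session): the trainable unwraps the broadcast only if present.
from flaml.tune.spark.utils import with_parameters, get_broadcast_data

def train(config, data=None):
    data = get_broadcast_data(data)  # no-op for plain objects, .value for Broadcast
    return {"val_loss": sum(data) * config.get("scale", 1.0)}

wrapped_train = with_parameters(train, data=list(range(10)))  # broadcasts `data`
print(wrapped_train(config={"scale": 0.1}))
```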


@ -135,3 +135,41 @@ class SequentialTrialRunner(BaseTrialRunner):
def stop_trial(self, trial):
super().stop_trial(trial)
self.running_trial = None
class SparkTrialRunner(BaseTrialRunner):
"""Implementation of the spark trial runner."""
def __init__(
self,
search_alg=None,
scheduler=None,
metric: Optional[str] = None,
mode: Optional[str] = "min",
):
super().__init__(search_alg, scheduler, metric, mode)
self.running_trials = []
def step(self) -> Trial:
"""Runs one step of the trial event loop.
Callers should typically run this method repeatedly in a loop. They
may inspect or modify the runner's state in between calls to step().
Returns:
a trial to run.
"""
trial_id = Trial.generate_id()
config = self._search_alg.suggest(trial_id)
if config is not None:
trial = SimpleTrial(config, trial_id)
self.add_trial(trial)
trial.set_status(Trial.RUNNING)
self.running_trials.append(trial)
else:
trial = None
return trial
def stop_trial(self, trial):
super().stop_trial(trial)
self.running_trials.remove(trial)


@ -7,6 +7,7 @@ import numpy as np
import datetime
import time
import os
import sys
from collections import defaultdict
try:
@ -15,9 +16,9 @@ try:
assert ray_version >= "1.10.0"
from ray.tune.analysis import ExperimentAnalysis as EA
ray_import = True
ray_available = True
except (ImportError, AssertionError):
ray_import = False
ray_available = False
from .analysis import ExperimentAnalysis as EA
from .trial import Trial
@ -25,6 +26,7 @@ from .result import DEFAULT_METRIC
import logging
logger = logging.getLogger(__name__)
logger.propagate = False
_use_ray = True
_runner = None
_verbose = 0
@ -226,6 +228,7 @@ def run(
metric_constraints: Optional[List[Tuple[str, str, float]]] = None,
max_failure: Optional[int] = 100,
use_ray: Optional[bool] = False,
use_spark: Optional[bool] = False,
use_incumbent_result_in_evaluation: Optional[bool] = None,
log_file_name: Optional[str] = None,
lexico_objectives: Optional[dict] = None,
@ -359,9 +362,10 @@ def run(
print(analysis.trials[-1].last_result)
```
verbose: 0, 1, 2, or 3. Verbosity mode for ray if ray backend is used.
0 = silent, 1 = only status updates, 2 = status and brief trial
results, 3 = status and detailed trial results. Defaults to 2.
verbose: 0, 1, 2, or 3. If ray or spark backend is used, their verbosity will be
affected by this argument. 0 = silent, 1 = only status updates,
2 = status and brief trial results, 3 = status and detailed trial results.
Defaults to 2.
local_dir: A string of the local dir to save ray logs if ray backend is
used; or a local dir to save the tuning log.
num_samples: An integer of the number of configs to try. Defaults to 1.
@ -380,6 +384,7 @@ def run(
max_failure: int | the maximal consecutive number of failures to sample
a trial before the tuning is terminated.
use_ray: A boolean of whether to use ray as the backend.
use_spark: A boolean of whether to use spark as the backend.
log_file_name: A string of the log file name. Default to None.
When set to None:
if local_dir is not given, no log file is created;
@ -423,7 +428,10 @@ def run(
log_file_name = os.path.join(
local_dir, "tune_" + str(datetime.datetime.now()).replace(":", "-") + ".log"
)
if use_ray and use_spark:
raise ValueError("use_ray and use_spark cannot be both True.")
if not use_ray:
_use_ray = False
_verbose = verbose
old_handlers = logger.handlers
old_level = logger.getEffectiveLevel()
@ -443,7 +451,7 @@ def run(
logger.addHandler(logging.FileHandler(log_file_name))
elif not logger.hasHandlers():
# Add the console handler.
_ch = logging.StreamHandler()
_ch = logging.StreamHandler(stream=sys.stdout)
logger_formatter = logging.Formatter(
"[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s",
"%m-%d %H:%M:%S",
@ -523,7 +531,7 @@ def run(
if metric is None or mode is None:
metric = metric or search_alg.metric or DEFAULT_METRIC
mode = mode or search_alg.mode
if ray_import:
if ray_available and use_ray:
if ray_version.startswith("1."):
from ray.tune.suggest import ConcurrencyLimiter
else:
@ -567,7 +575,7 @@ def run(
params["grace_period"] = min_resource
if reduction_factor:
params["reduction_factor"] = reduction_factor
if ray_import:
if ray_available:
from ray.tune.schedulers import ASHAScheduler
scheduler = ASHAScheduler(**params)
@ -605,6 +613,142 @@ def run(
_running_trial = old_running_trial
_training_iteration = old_training_iteration
if use_spark:
# parallel run with spark
from flaml.tune.spark.utils import check_spark
spark_available, spark_error_msg = check_spark()
if not spark_available:
raise spark_error_msg
try:
from pyspark.sql import SparkSession
from joblib import Parallel, delayed, parallel_backend
from joblibspark import register_spark
except ImportError as e:
raise ImportError(
f"{e}. Try pip install flaml[spark] or set use_spark=False."
)
from flaml.tune.searcher.suggestion import ConcurrencyLimiter
from .trial_runner import SparkTrialRunner
register_spark()
spark = SparkSession.builder.getOrCreate()
sc = spark._jsc.sc()
num_executors = (
len([executor.host() for executor in sc.statusTracker().getExecutorInfos()])
- 1
)
"""
By default, the number of executors is the number of VMs in the cluster. And we can
launch one trial per executor. However, sometimes we can launch more trials than
the number of executors (e.g., local mode). In this case, we can set the environment
variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`.
`max_concurrent` is the maximum number of concurrent trials defined by `search_alg`,
`FLAML_MAX_CONCURRENT` will also be used to override `max_concurrent` if `search_alg`
is not an instance of `ConcurrencyLimiter`.
The final number of concurrent trials is the minimum of `max_concurrent` and
`num_executors`.
"""
num_executors = max(num_executors, int(os.getenv("FLAML_MAX_CONCURRENT", 1)), 1)
time_start = time.time()
if scheduler:
scheduler.set_search_properties(metric=metric, mode=mode)
if isinstance(search_alg, ConcurrencyLimiter):
max_concurrent = max(1, search_alg.max_concurrent)
else:
max_concurrent = max(1, int(os.getenv("FLAML_MAX_CONCURRENT", 1)))
n_concurrent_trials = min(num_executors, max_concurrent)
with parallel_backend("spark"):
with Parallel(
n_jobs=n_concurrent_trials, verbose=max(0, (verbose - 1) * 50)
) as parallel:
try:
_runner = SparkTrialRunner(
search_alg=search_alg,
scheduler=scheduler,
metric=metric,
mode=mode,
)
num_trials = 0
if time_budget_s is None:
time_budget_s = np.inf
fail = 0
ub = (
len(evaluated_rewards) if evaluated_rewards else 0
) + max_failure
while (
time.time() - time_start < time_budget_s
and (num_samples < 0 or num_trials < num_samples)
and fail < ub
):
while len(_runner.running_trials) < n_concurrent_trials:
# suggest trials for spark
trial_next = _runner.step()
if trial_next:
num_trials += 1
else:
fail += 1 # break with ub consecutive failures
logger.debug(f"consecutive failures is {fail}")
if fail >= ub:
break
trials_to_run = _runner.running_trials
if not trials_to_run:
logger.warning(
f"fail to sample a trial for {max_failure} times in a row, stopping."
)
break
logger.info(
f"Number of trials: {num_trials}/{num_samples}, {len(_runner.running_trials)} RUNNING,"
f" {len(_runner._trials) - len(_runner.running_trials)} TERMINATED"
)
logger.debug(
f"Configs of Trials to run: {[trial_to_run.config for trial_to_run in trials_to_run]}"
)
results = parallel(
delayed(evaluation_function)(trial_to_run.config)
for trial_to_run in trials_to_run
)
# results = [evaluation_function(trial_to_run.config) for trial_to_run in trials_to_run]
while results:
result = results.pop(0)
trial_to_run = trials_to_run[0]
_runner.running_trial = trial_to_run
if result is not None:
if isinstance(result, dict):
if result:
logger.info(f"Brief result: {result}")
report(**result)
else:
# When the result returned is an empty dict, set the trial status to error
trial_to_run.set_status(Trial.ERROR)
else:
logger.info(
"Brief result: {}".format({metric: result})
)
report(_metric=result)
_runner.stop_trial(trial_to_run)
fail = 0
analysis = ExperimentAnalysis(
_runner.get_trials(),
metric=metric,
mode=mode,
lexico_objectives=lexico_objectives,
)
return analysis
finally:
# recover the global variables in case of nested run
_use_ray = old_use_ray
_verbose = old_verbose
_running_trial = old_running_trial
_training_iteration = old_training_iteration
if not use_ray:
_runner = old_runner
logger.handlers = old_handlers
logger.setLevel(old_level)
# simple sequential run without using tune.run() from ray
time_start = time.time()
_use_ray = False
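An end-to-end sketch of the new backend option in `flaml.tune.run` follows; the objective and search space are toy placeholders, and `use_spark=True` requires the `flaml[spark]` extras.

```python
# Toy end-to-end sketch of tune.run with the Spark backend (requires
# pip install flaml[spark]); objective and search space are placeholders.
from flaml import tune

def evaluate(config):
    # returning a dict lets the runner record the metric via report(**result)
    return {"val_loss": (config["x"] - 3) ** 2}

analysis = tune.run(
    evaluate,
    config={"x": tune.uniform(lower=0, upper=10)},
    metric="val_loss",
    mode="min",
    num_samples=20,
    use_spark=True,
)
print(analysis.best_config)
```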


@ -1041,7 +1041,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 64-bit",
"display_name": "Python 3.8.13 ('syml-py38')",
"language": "python",
"name": "python3"
},
@ -1055,11 +1055,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.8.13"
},
"vscode": {
"interpreter": {
"hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1"
"hash": "e3d9487e2ef008ade0db1bc293d3206d35cb2b6081faff9f66b40b257b7398f7"
}
}
},


@ -203,7 +203,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.7 ('base')",
"display_name": "Python 3.8.13 ('syml-py38')",
"language": "python",
"name": "python3"
},
@ -217,11 +217,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
"version": "3.8.13"
},
"vscode": {
"interpreter": {
"hash": "e811209110f5aa4d8c2189eeb3ff7b9b4d146931cb9189ef6041ff71605c541d"
"hash": "e3d9487e2ef008ade0db1bc293d3206d35cb2b6081faff9f66b40b257b7398f7"
}
}
},

File diff suppressed because one or more lines are too long


@ -44,6 +44,10 @@ setuptools.setup(
"matplotlib",
"openml==0.10.2",
],
"spark": [
"pyspark>=3.0.0",
"joblibspark>=0.5.0",
],
"test": [
"flake8>=3.8.4",
"thop",
@ -67,6 +71,10 @@ setuptools.setup(
"seqeval",
"pytorch-forecasting>=0.9.0,<=0.10.1",
"mlflow",
"pyspark>=3.0.0",
"joblibspark>=0.5.0",
"nbconvert",
"nbformat",
],
"catboost": ["catboost>=0.26"],
"blendsearch": ["optuna==2.8.0"],

test/spark/__init__.py (new empty file)


@ -0,0 +1,124 @@
from flaml.tune.spark.utils import broadcast_code
custom_code = """
from flaml import tune
from flaml.automl.model import LGBMEstimator, XGBoostSklearnEstimator, SKLearnEstimator
from flaml.automl.data import CLASSIFICATION, get_output_from_log
class MyRegularizedGreedyForest(SKLearnEstimator):
def __init__(self, task="binary", **config):
super().__init__(task, **config)
if task in CLASSIFICATION:
from rgf.sklearn import RGFClassifier
self.estimator_class = RGFClassifier
else:
from rgf.sklearn import RGFRegressor
self.estimator_class = RGFRegressor
@classmethod
def search_space(cls, data_size, task):
space = {
"max_leaf": {
"domain": tune.lograndint(lower=4, upper=data_size[0]),
"init_value": 4,
},
"n_iter": {
"domain": tune.lograndint(lower=1, upper=data_size[0]),
"init_value": 1,
},
"n_tree_search": {
"domain": tune.lograndint(lower=1, upper=32768),
"init_value": 1,
},
"opt_interval": {
"domain": tune.lograndint(lower=1, upper=10000),
"init_value": 100,
},
"learning_rate": {"domain": tune.loguniform(lower=0.01, upper=20.0)},
"min_samples_leaf": {
"domain": tune.lograndint(lower=1, upper=20),
"init_value": 20,
},
}
return space
@classmethod
def size(cls, config):
max_leaves = int(round(config.get("max_leaf", 1)))
n_estimators = int(round(config.get("n_iter", 1)))
return (max_leaves * 3 + (max_leaves - 1) * 4 + 1.0) * n_estimators * 8
@classmethod
def cost_relative2lgbm(cls):
return 1.0
class MyLargeXGB(XGBoostSklearnEstimator):
@classmethod
def search_space(cls, **params):
return {
"n_estimators": {
"domain": tune.lograndint(lower=4, upper=32768),
"init_value": 32768,
"low_cost_init_value": 4,
},
"max_leaves": {
"domain": tune.lograndint(lower=4, upper=3276),
"init_value": 3276,
"low_cost_init_value": 4,
},
}
class MyLargeLGBM(LGBMEstimator):
@classmethod
def search_space(cls, **params):
return {
"n_estimators": {
"domain": tune.lograndint(lower=4, upper=32768),
"init_value": 32768,
"low_cost_init_value": 4,
},
"num_leaves": {
"domain": tune.lograndint(lower=4, upper=3276),
"init_value": 3276,
"low_cost_init_value": 4,
},
}
def custom_metric(
X_val,
y_val,
estimator,
labels,
X_train,
y_train,
weight_val=None,
weight_train=None,
config=None,
groups_val=None,
groups_train=None,
):
from sklearn.metrics import log_loss
import time
start = time.time()
y_pred = estimator.predict_proba(X_val)
pred_time = (time.time() - start) / len(X_val)
val_loss = log_loss(y_val, y_pred, labels=labels, sample_weight=weight_val)
y_pred = estimator.predict_proba(X_train)
train_loss = log_loss(y_train, y_pred, labels=labels, sample_weight=weight_train)
alpha = 0.5
return val_loss * (1 + alpha) - alpha * train_loss, {
"val_loss": val_loss,
"train_loss": train_loss,
"pred_time": pred_time,
}
"""
_ = broadcast_code(custom_code=custom_code)

test/spark/mylearner.py (new file)

@ -0,0 +1,19 @@
from flaml.automl.model import LGBMEstimator
from flaml import tune
class MyLargeLGBM(LGBMEstimator):
@classmethod
def search_space(cls, **params):
return {
"n_estimators": {
"domain": tune.lograndint(lower=4, upper=32768),
"init_value": 32768,
"low_cost_init_value": 4,
},
"num_leaves": {
"domain": tune.lograndint(lower=4, upper=32768),
"init_value": 32768,
"low_cost_init_value": 4,
},
}

test/spark/test_automl.py (new file)

@ -0,0 +1,108 @@
import numpy as np
import scipy.sparse
from flaml import AutoML
from flaml.tune.spark.utils import check_spark
import os
import pytest
# For spark, we need to put customized learner in a separate file
if os.path.exists(os.path.join(os.getcwd(), "test", "spark", "mylearner.py")):
try:
from test.spark.mylearner import MyLargeLGBM
skip_my_learner = False
except ImportError:
skip_my_learner = True
MyLargeLGBM = None
else:
MyLargeLGBM = None
skip_my_learner = True
os.environ["FLAML_MAX_CONCURRENT"] = "2"
spark_available, _ = check_spark()
skip_spark = not spark_available
pytestmark = pytest.mark.skipif(
skip_spark, reason="Spark is not installed. Skip all spark tests."
)
def test_parallel_xgboost(hpo_method=None, data_size=1000):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 10,
"metric": "ap",
"task": "classification",
"log_file_name": "test/sparse_classification.log",
"estimator_list": ["xgboost"],
"log_type": "all",
"n_jobs": 1,
"n_concurrent_trials": 2,
"hpo_method": hpo_method,
"use_spark": True,
}
X_train = scipy.sparse.eye(data_size)
y_train = np.random.randint(2, size=data_size)
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl_experiment.predict(X_train))
print(automl_experiment.model)
print(automl_experiment.config_history)
print(automl_experiment.best_model_for_estimator("xgboost"))
print(automl_experiment.best_iteration)
print(automl_experiment.best_estimator)
def test_parallel_xgboost_others():
# use random search as the hpo_method
test_parallel_xgboost(hpo_method="random")
@pytest.mark.skip(
reason="currently not supporting too large data, will support spark dataframe in the future"
)
def test_large_dataset():
test_parallel_xgboost(data_size=90000000)
@pytest.mark.skipif(
skip_my_learner,
reason="please run pytest in the root directory of FLAML, i.e., the directory that contains the setup.py file",
)
def test_custom_learner(data_size=1000):
automl_experiment = AutoML()
automl_experiment.add_learner(learner_name="large_lgbm", learner_class=MyLargeLGBM)
automl_settings = {
"time_budget": 2,
"task": "classification",
"log_file_name": "test/sparse_classification_oom.log",
"estimator_list": ["large_lgbm"],
"log_type": "all",
"n_jobs": 1,
"hpo_method": "random",
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train = scipy.sparse.eye(data_size)
y_train = np.random.randint(2, size=data_size)
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl_experiment.predict(X_train))
print(automl_experiment.model)
print(automl_experiment.config_history)
print(automl_experiment.best_model_for_estimator("large_lgbm"))
print(automl_experiment.best_iteration)
print(automl_experiment.best_estimator)
if __name__ == "__main__":
test_parallel_xgboost()
test_parallel_xgboost_others()
# test_large_dataset()
if skip_my_learner:
print(
"please run pytest in the root directory of FLAML, i.e., the directory that contains the setup.py file"
)
else:
test_custom_learner()


@ -0,0 +1,57 @@
import unittest
from sklearn.datasets import load_wine
from flaml import AutoML
from flaml.tune.spark.utils import check_spark
import os
spark_available, _ = check_spark()
skip_spark = not spark_available
os.environ["FLAML_MAX_CONCURRENT"] = "2"
# To solve pylint issue, we put code for customizing mylearner in a separate file
if os.path.exists(os.path.join(os.getcwd(), "test", "spark", "custom_mylearner.py")):
try:
from test.spark.custom_mylearner import *
from flaml.tune.spark.mylearner import MyRegularizedGreedyForest
skip_my_learner = False
except ImportError:
skip_my_learner = True
else:
skip_my_learner = True
class TestEnsemble(unittest.TestCase):
def setUp(self) -> None:
if skip_spark:
self.skipTest("Spark is not installed. Skip all spark tests.")
@unittest.skipIf(
skip_my_learner,
"Please run pytest in the root directory of FLAML, i.e., the directory that contains the setup.py file",
)
def test_ensemble(self):
automl = AutoML()
automl.add_learner(learner_name="RGF", learner_class=MyRegularizedGreedyForest)
X_train, y_train = load_wine(return_X_y=True)
settings = {
"time_budget": 5, # total running time in seconds
"estimator_list": ["rf", "xgboost", "catboost"],
"task": "classification", # task type
"sample": True, # whether to subsample training data
"log_file_name": "test/wine.log",
"log_training_metric": True, # whether to log training metric
"ensemble": {
"final_estimator": MyRegularizedGreedyForest(),
"passthrough": False,
},
"n_jobs": 1,
"n_concurrent_trials": 2,
"use_spark": True,
}
automl.fit(X_train=X_train, y_train=y_train, **settings)
if __name__ == "__main__":
unittest.main()


@ -0,0 +1,76 @@
from flaml.automl.data import load_openml_dataset
from flaml import AutoML
from flaml.tune.spark.utils import check_spark
import os
import pytest
spark_available, _ = check_spark()
skip_spark = not spark_available
pytestmark = pytest.mark.skipif(
skip_spark, reason="Spark is not installed. Skip all spark tests."
)
os.environ["FLAML_MAX_CONCURRENT"] = "2"
def base_automl(n_concurrent_trials=1, use_ray=False, use_spark=False, verbose=0):
X_train, X_test, y_train, y_test = load_openml_dataset(
dataset_id=537, data_dir="./"
)
automl = AutoML()
settings = {
"time_budget": 3, # total running time in seconds
"metric": "r2", # primary metrics for regression can be chosen from: ['mae','mse','r2','rmse','mape']
"estimator_list": ["lgbm", "rf", "xgboost"], # list of ML learners
"task": "regression", # task type
"log_file_name": "houses_experiment.log", # flaml log file
"seed": 7654321, # random seed
"n_concurrent_trials": n_concurrent_trials, # the maximum number of concurrent learners
"use_ray": use_ray, # whether to use Ray for distributed training
"use_spark": use_spark, # whether to use Spark for distributed training
"verbose": verbose,
}
automl.fit(X_train=X_train, y_train=y_train, **settings)
print("Best ML leaner:", automl.best_estimator)
print("Best hyperparmeter config:", automl.best_config)
print("Best accuracy on validation data: {0:.4g}".format(1 - automl.best_loss))
print(
"Training duration of best run: {0:.4g} s".format(automl.best_config_train_time)
)
def test_both_ray_spark():
with pytest.raises(ValueError):
base_automl(n_concurrent_trials=2, use_ray=True, use_spark=True)
def test_verboses():
for verbose in [1, 3, 5]:
base_automl(verbose=verbose)
def test_import_error():
from importlib import reload
import flaml.tune.spark.utils as utils
reload(utils)
utils._have_spark = False
spark_available, spark_error_msg = utils.check_spark()
assert not spark_available
assert isinstance(spark_error_msg, ImportError)
reload(utils)
utils._spark_major_minor_version = (1, 1)
spark_available, spark_error_msg = utils.check_spark()
assert not spark_available
assert isinstance(spark_error_msg, ImportError)
reload(utils)
if __name__ == "__main__":
base_automl()
test_import_error()


@ -0,0 +1,470 @@
import unittest
import numpy as np
import scipy.sparse
from sklearn.datasets import load_iris, load_wine
from flaml import AutoML
from flaml.automl.data import CLASSIFICATION, get_output_from_log
from flaml.automl.training_log import training_log_reader
from flaml.tune.spark.utils import check_spark
import os
spark_available, _ = check_spark()
skip_spark = not spark_available
os.environ["FLAML_MAX_CONCURRENT"] = "2"
# To solve pylint issue, we put code for customizing mylearner in a separate file
if os.path.exists(os.path.join(os.getcwd(), "test", "spark", "custom_mylearner.py")):
try:
from test.spark.custom_mylearner import *
from flaml.tune.spark.mylearner import (
MyRegularizedGreedyForest,
custom_metric,
MyLargeLGBM,
MyLargeXGB,
)
skip_my_learner = False
except ImportError:
skip_my_learner = True
else:
skip_my_learner = True
class TestMultiClass(unittest.TestCase):
def setUp(self) -> None:
if skip_spark:
self.skipTest("Spark is not installed. Skip all spark tests.")
@unittest.skipIf(
skip_my_learner,
"Please run pytest in the root directory of FLAML, i.e., the directory that contains the setup.py file",
)
def test_custom_learner(self):
automl = AutoML()
automl.add_learner(learner_name="RGF", learner_class=MyRegularizedGreedyForest)
X_train, y_train = load_wine(return_X_y=True)
settings = {
"time_budget": 8, # total running time in seconds
"estimator_list": ["RGF", "lgbm", "rf", "xgboost"],
"task": "classification", # task type
"sample": True, # whether to subsample training data
"log_file_name": "test/wine.log",
"log_training_metric": True, # whether to log training metric
"n_jobs": 1,
"n_concurrent_trials": 2,
"use_spark": True,
"verbose": 4,
}
automl.fit(X_train=X_train, y_train=y_train, **settings)
# print the best model found for RGF
print(automl.best_model_for_estimator("RGF"))
MyRegularizedGreedyForest.search_space = lambda data_size, task: {}
automl.fit(X_train=X_train, y_train=y_train, **settings)
@unittest.skipIf(
skip_my_learner,
"Please run pytest in the root directory of FLAML, i.e., the directory that contains the setup.py file",
)
def test_custom_metric(self):
df, y = load_iris(return_X_y=True, as_frame=True)
df["label"] = y
automl_experiment = AutoML()
automl_settings = {
"dataframe": df,
"label": "label",
"time_budget": 5,
"eval_method": "cv",
"metric": custom_metric,
"task": "classification",
"log_file_name": "test/iris_custom.log",
"log_training_metric": True,
"log_type": "all",
"n_jobs": 1,
"model_history": True,
"sample_weight": np.ones(len(y)),
"pred_time_limit": 1e-5,
# "ensemble": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
automl_experiment.fit(**automl_settings)
print(automl_experiment.classes_)
print(automl_experiment.model)
print(automl_experiment.config_history)
print(automl_experiment.best_model_for_estimator("rf"))
print(automl_experiment.best_iteration)
print(automl_experiment.best_estimator)
automl_experiment = AutoML()
estimator = automl_experiment.get_estimator_from_log(
automl_settings["log_file_name"], record_id=0, task="multiclass"
)
print(estimator)
(
time_history,
best_valid_loss_history,
valid_loss_history,
config_history,
metric_history,
) = get_output_from_log(
filename=automl_settings["log_file_name"], time_budget=6
)
print(metric_history)
def test_classification(self, as_frame=False):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 4,
"metric": "accuracy",
"task": "classification",
"log_file_name": "test/iris.log",
"log_training_metric": True,
"n_jobs": 1,
"model_history": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train, y_train = load_iris(return_X_y=True, as_frame=as_frame)
if as_frame:
# test drop column
X_train.columns = range(X_train.shape[1])
X_train[X_train.shape[1]] = np.zeros(len(y_train))
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl_experiment.classes_)
print(automl_experiment.predict(X_train)[:5])
print(automl_experiment.model)
print(automl_experiment.config_history)
print(automl_experiment.best_model_for_estimator("catboost"))
print(automl_experiment.best_iteration)
print(automl_experiment.best_estimator)
del automl_settings["metric"]
del automl_settings["model_history"]
del automl_settings["log_training_metric"]
automl_experiment = AutoML(task="classification")
duration = automl_experiment.retrain_from_log(
log_file_name=automl_settings["log_file_name"],
X_train=X_train,
y_train=y_train,
train_full=True,
record_id=0,
)
print(duration)
print(automl_experiment.model)
print(automl_experiment.predict_proba(X_train)[:5])
def test_micro_macro_f1(self):
automl_experiment_micro = AutoML()
automl_experiment_macro = AutoML()
automl_settings = {
"time_budget": 2,
"task": "classification",
"log_file_name": "test/micro_macro_f1.log",
"log_training_metric": True,
"n_jobs": 1,
"model_history": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train, y_train = load_iris(return_X_y=True)
automl_experiment_micro.fit(
X_train=X_train, y_train=y_train, metric="micro_f1", **automl_settings
)
automl_experiment_macro.fit(
X_train=X_train, y_train=y_train, metric="macro_f1", **automl_settings
)
estimator = automl_experiment_macro.model
y_pred = estimator.predict(X_train)
y_pred_proba = estimator.predict_proba(X_train)
from flaml.automl.ml import norm_confusion_matrix, multi_class_curves
print(norm_confusion_matrix(y_train, y_pred))
from sklearn.metrics import roc_curve, precision_recall_curve
print(multi_class_curves(y_train, y_pred_proba, roc_curve))
print(multi_class_curves(y_train, y_pred_proba, precision_recall_curve))
def test_roc_auc_ovr(self):
automl_experiment = AutoML()
X_train, y_train = load_iris(return_X_y=True)
automl_settings = {
"time_budget": 1,
"metric": "roc_auc_ovr",
"task": "classification",
"log_file_name": "test/roc_auc_ovr.log",
"log_training_metric": True,
"n_jobs": 1,
"sample_weight": np.ones(len(y_train)),
"eval_method": "holdout",
"model_history": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
def test_roc_auc_ovo(self):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 1,
"metric": "roc_auc_ovo",
"task": "classification",
"log_file_name": "test/roc_auc_ovo.log",
"log_training_metric": True,
"n_jobs": 1,
"model_history": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train, y_train = load_iris(return_X_y=True)
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
def test_roc_auc_ovr_weighted(self):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 1,
"metric": "roc_auc_ovr_weighted",
"task": "classification",
"log_file_name": "test/roc_auc_weighted.log",
"log_training_metric": True,
"n_jobs": 1,
"model_history": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train, y_train = load_iris(return_X_y=True)
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
def test_roc_auc_ovo_weighted(self):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 1,
"metric": "roc_auc_ovo_weighted",
"task": "classification",
"log_file_name": "test/roc_auc_weighted.log",
"log_training_metric": True,
"n_jobs": 1,
"model_history": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train, y_train = load_iris(return_X_y=True)
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
def test_sparse_matrix_classification(self):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 2,
"metric": "auto",
"task": "classification",
"log_file_name": "test/sparse_classification.log",
"split_type": "uniform",
"n_jobs": 1,
"model_history": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train = scipy.sparse.random(1554, 21, dtype=int)
y_train = np.random.randint(3, size=1554)
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl_experiment.classes_)
print(automl_experiment.predict_proba(X_train))
print(automl_experiment.model)
print(automl_experiment.config_history)
print(automl_experiment.best_model_for_estimator("extra_tree"))
print(automl_experiment.best_iteration)
print(automl_experiment.best_estimator)
@unittest.skipIf(
skip_my_learner,
"Please run pytest in the root directory of FLAML, i.e., the directory that contains the setup.py file",
)
def _test_memory_limit(self):
automl_experiment = AutoML()
automl_experiment.add_learner(
learner_name="large_lgbm", learner_class=MyLargeLGBM
)
automl_settings = {
"time_budget": -1,
"task": "classification",
"log_file_name": "test/classification_oom.log",
"estimator_list": ["large_lgbm"],
"log_type": "all",
"hpo_method": "random",
"free_mem_ratio": 0.2,
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train, y_train = load_iris(return_X_y=True, as_frame=True)
automl_experiment.fit(
X_train=X_train, y_train=y_train, max_iter=1, **automl_settings
)
print(automl_experiment.model)
@unittest.skipIf(
skip_my_learner,
"Please run pytest in the root directory of FLAML, i.e., the directory that contains the setup.py file",
)
def test_time_limit(self):
automl_experiment = AutoML()
automl_experiment.add_learner(
learner_name="large_lgbm", learner_class=MyLargeLGBM
)
automl_experiment.add_learner(
learner_name="large_xgb", learner_class=MyLargeXGB
)
automl_settings = {
"time_budget": 0.5,
"task": "classification",
"log_file_name": "test/classification_timeout.log",
"estimator_list": ["catboost"],
"log_type": "all",
"hpo_method": "random",
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train, y_train = load_iris(return_X_y=True, as_frame=True)
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl_experiment.model.params)
automl_settings["estimator_list"] = ["large_xgb"]
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl_experiment.model)
automl_settings["estimator_list"] = ["large_lgbm"]
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
print(automl_experiment.model)
def test_fit_w_starting_point(self, as_frame=True):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 3,
"metric": "accuracy",
"task": "classification",
"log_file_name": "test/iris.log",
"log_training_metric": True,
"n_jobs": 1,
"model_history": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train, y_train = load_iris(return_X_y=True, as_frame=as_frame)
if as_frame:
# test drop column
X_train.columns = range(X_train.shape[1])
X_train[X_train.shape[1]] = np.zeros(len(y_train))
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
automl_val_accuracy = 1.0 - automl_experiment.best_loss
print("Best ML leaner:", automl_experiment.best_estimator)
print("Best hyperparmeter config:", automl_experiment.best_config)
print("Best accuracy on validation data: {0:.4g}".format(automl_val_accuracy))
print(
"Training duration of best run: {0:.4g} s".format(
automl_experiment.best_config_train_time
)
)
starting_points = automl_experiment.best_config_per_estimator
print("starting_points", starting_points)
print("loss of the starting_points", automl_experiment.best_loss_per_estimator)
automl_settings_resume = {
"time_budget": 2,
"metric": "accuracy",
"task": "classification",
"log_file_name": "test/iris_resume.log",
"log_training_metric": True,
"n_jobs": 1,
"model_history": True,
"log_type": "all",
"starting_points": starting_points,
"n_concurrent_trials": 2,
"use_spark": True,
}
new_automl_experiment = AutoML()
new_automl_experiment.fit(
X_train=X_train, y_train=y_train, **automl_settings_resume
)
new_automl_val_accuracy = 1.0 - new_automl_experiment.best_loss
print("Best ML leaner:", new_automl_experiment.best_estimator)
print("Best hyperparmeter config:", new_automl_experiment.best_config)
print(
"Best accuracy on validation data: {0:.4g}".format(new_automl_val_accuracy)
)
print(
"Training duration of best run: {0:.4g} s".format(
new_automl_experiment.best_config_train_time
)
)
def test_fit_w_starting_points_list(self, as_frame=True):
automl_experiment = AutoML()
automl_settings = {
"time_budget": 3,
"metric": "accuracy",
"task": "classification",
"log_file_name": "test/iris.log",
"log_training_metric": True,
"n_jobs": 1,
"model_history": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
X_train, y_train = load_iris(return_X_y=True, as_frame=as_frame)
if as_frame:
# test drop column
X_train.columns = range(X_train.shape[1])
X_train[X_train.shape[1]] = np.zeros(len(y_train))
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
automl_val_accuracy = 1.0 - automl_experiment.best_loss
print("Best ML leaner:", automl_experiment.best_estimator)
print("Best hyperparmeter config:", automl_experiment.best_config)
print("Best accuracy on validation data: {0:.4g}".format(automl_val_accuracy))
print(
"Training duration of best run: {0:.4g} s".format(
automl_experiment.best_config_train_time
)
)
starting_points = {}
log_file_name = automl_settings["log_file_name"]
with training_log_reader(log_file_name) as reader:
sample_size = 1000
for record in reader.records():
config = record.config
config["FLAML_sample_size"] = sample_size
sample_size += 1000
learner = record.learner
if learner not in starting_points:
starting_points[learner] = []
starting_points[learner].append(config)
max_iter = sum([len(s) for k, s in starting_points.items()])
automl_settings_resume = {
"time_budget": 2,
"metric": "accuracy",
"task": "classification",
"log_file_name": "test/iris_resume_all.log",
"log_training_metric": True,
"n_jobs": 1,
"max_iter": max_iter,
"model_history": True,
"log_type": "all",
"starting_points": starting_points,
"append_log": True,
"n_concurrent_trials": 2,
"use_spark": True,
}
new_automl_experiment = AutoML()
new_automl_experiment.fit(
X_train=X_train, y_train=y_train, **automl_settings_resume
)
new_automl_val_accuracy = 1.0 - new_automl_experiment.best_loss
# print('Best ML learner:', new_automl_experiment.best_estimator)
# print('Best hyperparameter config:', new_automl_experiment.best_config)
print(
"Best accuracy on validation data: {0:.4g}".format(new_automl_val_accuracy)
)
# print('Training duration of best run: {0:.4g} s'.format(new_automl_experiment.best_config_train_time))
if __name__ == "__main__":
unittest.main()


@ -0,0 +1,41 @@
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor
from nbconvert.preprocessors import CellExecutionError
from flaml.tune.spark.utils import check_spark
import os
import pytest
spark_available, _ = check_spark()
skip_spark = not spark_available
pytestmark = pytest.mark.skipif(
skip_spark, reason="Spark is not installed. Skip all spark tests."
)
here = os.path.abspath(os.path.dirname(__file__))
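# cap concurrent Spark trials at 2 (overrides the number of executors detected by FLAML)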
os.environ["FLAML_MAX_CONCURRENT"] = "2"
def run_notebook(input_nb, output_nb="executed_notebook.ipynb", save=False):
try:
file_path = os.path.join(here, os.pardir, os.pardir, "notebook", input_nb)
with open(file_path) as f:
nb = nbformat.read(f, as_version=4)
ep = ExecutePreprocessor(timeout=600, kernel_name="python3")
ep.preprocess(nb, {"metadata": {"path": here}})
except CellExecutionError:
raise
except Exception as e:
print("\nIgnoring below error:\n", e, "\n\n")
finally:
if save:
with open(os.path.join(here, output_nb), "w", encoding="utf-8") as f:
nbformat.write(nb, f)
def test_automl_lightgbm_test():
run_notebook("integrate_spark.ipynb")
if __name__ == "__main__":
test_automl_lightgbm_test()


@ -0,0 +1,110 @@
import sys
from openml.exceptions import OpenMLServerException
from requests.exceptions import ChunkedEncodingError, SSLError
from flaml.tune.spark.utils import check_spark
import os
import pytest
spark_available, _ = check_spark()
skip_spark = not spark_available
pytestmark = pytest.mark.skipif(
skip_spark, reason="Spark is not installed. Skip all spark tests."
)
os.environ["FLAML_MAX_CONCURRENT"] = "2"
def run_automl(budget=3, dataset_format="dataframe", hpo_method=None):
from flaml.automl.data import load_openml_dataset
import urllib3
performance_check_budget = 3600
if sys.platform == "darwin" or "nt" in os.name or "3.10" not in sys.version:
budget = 3  # revise the budget if the platform is not linux + python 3.10
if budget >= performance_check_budget:
max_iter = 60
performance_check_budget = None
else:
max_iter = None
try:
X_train, X_test, y_train, y_test = load_openml_dataset(
dataset_id=1169, data_dir="test/", dataset_format=dataset_format
)
except (
OpenMLServerException,
ChunkedEncodingError,
urllib3.exceptions.ReadTimeoutError,
SSLError,
) as e:
print(e)
return
""" import AutoML class from flaml package """
from flaml import AutoML
automl = AutoML()
settings = {
"time_budget": budget, # total running time in seconds
"max_iter": max_iter, # maximum number of iterations
"metric": "accuracy", # primary metrics can be chosen from: ['accuracy','roc_auc','roc_auc_ovr','roc_auc_ovo','f1','log_loss','mae','mse','r2']
"task": "classification", # task type
"log_file_name": "airlines_experiment.log", # flaml log file
"seed": 7654321, # random seed
"hpo_method": hpo_method,
"log_type": "all",
"estimator_list": [
"lgbm",
"xgboost",
"xgb_limitdepth",
"rf",
"extra_tree",
], # list of ML learners
"eval_method": "holdout",
"n_concurrent_trials": 2,
"use_spark": True,
}
"""The main flaml automl API"""
automl.fit(X_train=X_train, y_train=y_train, **settings)
""" retrieve best config and best learner """
print("Best ML leaner:", automl.best_estimator)
print("Best hyperparmeter config:", automl.best_config)
print("Best accuracy on validation data: {0:.4g}".format(1 - automl.best_loss))
print(
"Training duration of best run: {0:.4g} s".format(automl.best_config_train_time)
)
print(automl.model.estimator)
print(automl.best_config_per_estimator)
print("time taken to find best model:", automl.time_to_find_best_model)
""" compute predictions of testing dataset """
y_pred = automl.predict(X_test)
print("Predicted labels", y_pred)
print("True labels", y_test)
y_pred_proba = automl.predict_proba(X_test)[:, 1]
""" compute different metric values on testing dataset """
from flaml.automl.ml import sklearn_metric_loss_score
accuracy = 1 - sklearn_metric_loss_score("accuracy", y_pred, y_test)
print("accuracy", "=", accuracy)
print(
"roc_auc", "=", 1 - sklearn_metric_loss_score("roc_auc", y_pred_proba, y_test)
)
print("log_loss", "=", sklearn_metric_loss_score("log_loss", y_pred_proba, y_test))
if performance_check_budget is None:
assert accuracy >= 0.669, "the accuracy of flaml should be no less than 0.669"
def test_automl_array():
run_automl(3, "array", "bs")
def test_automl_performance():
run_automl(3600)
if __name__ == "__main__":
test_automl_array()
test_automl_performance()

test/spark/test_tune.py

@ -0,0 +1,58 @@
import lightgbm as lgb
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from flaml import tune
from flaml.automl.model import LGBMEstimator
from flaml.tune.spark.utils import check_spark
import os
import pytest
spark_available, _ = check_spark()
skip_spark = not spark_available
pytestmark = pytest.mark.skipif(
skip_spark, reason="Spark is not installed. Skip all spark tests."
)
os.environ["FLAML_MAX_CONCURRENT"] = "2"
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25)
def train_breast_cancer(config):
params = LGBMEstimator(**config).params
train_set = lgb.Dataset(X_train, label=y_train)
gbm = lgb.train(params, train_set)
preds = gbm.predict(X_test)
pred_labels = np.rint(preds)
result = {
"mean_accuracy": accuracy_score(y_test, pred_labels),
}
return result
def test_tune_spark():
flaml_lgbm_search_space = LGBMEstimator.search_space(X_train.shape)
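# keep only the tune search domain of each hyperparameter from FLAML's LGBM search space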
config_search_space = {
hp: space["domain"] for hp, space in flaml_lgbm_search_space.items()
}
analysis = tune.run(
train_breast_cancer,
metric="mean_accuracy",
mode="max",
config=config_search_space,
num_samples=-1,
time_budget_s=5,
use_spark=True,
verbose=3,
)
# print("Best hyperparameters found were: ", analysis.best_config)
print("The best trial's result: ", analysis.best_trial.last_result)
if __name__ == "__main__":
test_tune_spark()

test/spark/test_utils.py

@ -0,0 +1,101 @@
from flaml.tune.spark.utils import (
with_parameters,
check_spark,
get_n_cpus,
get_broadcast_data,
)
from functools import partial
from timeit import timeit
import pytest
try:
from pyspark.sql import SparkSession
import pyspark
spark_available, _ = check_spark()
skip_spark = not spark_available
except ImportError:
print("Spark is not installed. Skip all spark tests.")
skip_spark = True
pytestmark = pytest.mark.skipif(
skip_spark, reason="Spark is not installed. Skip all spark tests."
)
def test_with_parameters_spark():
def train(config, data=None):
if isinstance(data, pyspark.broadcast.Broadcast):
data = data.value
print(config, len(data))
data = ["a"] * 10**6
with_parameters_train = with_parameters(train, data=data)
partial_train = partial(train, data=data)
spark = SparkSession.builder.getOrCreate()
rdd = spark.sparkContext.parallelize(list(range(2)))
t_partial = timeit(
lambda: rdd.map(lambda x: partial_train(config=x)).collect(), number=5
)
print("python_partial_train: " + str(t_partial))
t_spark = timeit(
lambda: rdd.map(lambda x: with_parameters_train(config=x)).collect(),
number=5,
)
print("spark_with_parameters_train: " + str(t_spark))
# assert t_spark < t_partial
def test_get_n_cpus_spark():
n_cpus = get_n_cpus()
assert isinstance(n_cpus, int)
def test_broadcast_code():
from flaml.tune.spark.utils import broadcast_code
from flaml.automl.model import LGBMEstimator
custom_code = """
from flaml.automl.model import LGBMEstimator
from flaml import tune
class MyLargeLGBM(LGBMEstimator):
@classmethod
def search_space(cls, **params):
return {
"n_estimators": {
"domain": tune.lograndint(lower=4, upper=32768),
"init_value": 32768,
"low_cost_init_value": 4,
},
"num_leaves": {
"domain": tune.lograndint(lower=4, upper=32768),
"init_value": 32768,
"low_cost_init_value": 4,
},
}
"""
_ = broadcast_code(custom_code=custom_code)
from flaml.tune.spark.mylearner import MyLargeLGBM
assert isinstance(MyLargeLGBM(), LGBMEstimator)
def test_get_broadcast_data():
data = ["a"] * 10
spark = SparkSession.builder.getOrCreate()
bc_data = spark.sparkContext.broadcast(data)
assert get_broadcast_data(bc_data) == data
if __name__ == "__main__":
test_with_parameters_spark()
test_get_n_cpus_spark()
test_broadcast_code()
test_get_broadcast_data()


@ -194,8 +194,8 @@ def test_searcher():
searcher.on_trial_complete("t2", None, True)
searcher.suggest("t3")
searcher.on_trial_complete("t3", {"m": np.nan})
searcher.save("test/tune/optuna.pickle")
searcher.restore("test/tune/optuna.pickle")
searcher.save("test/tune/optuna.pkl")
searcher.restore("test/tune/optuna.pkl")
try:
searcher = BlendSearch(
metric="m", global_search_alg=searcher, metric_constraints=[("c", "<", 1)]


@ -50,6 +50,28 @@ pip install flaml[nlp]
```bash
pip install flaml[ray]
```
* spark
> *Spark support was added in v1.1.0*
```bash
pip install "flaml[spark]>=1.1.0"
```
Cloud platforms such as [Azure Synapse](https://azure.microsoft.com/en-us/products/synapse-analytics/) provide Spark clusters out of the box,
but you may need to install `Spark` manually when setting up your own environment.
On a recent Ubuntu system, you can install the Spark 3.3.0 standalone version with the script below.
For more details on installing Spark, please refer to the [Spark documentation](https://spark.apache.org/docs/latest/api/python/getting_started/install.html).
```bash
sudo apt-get update && sudo apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends \
ca-certificates-java ca-certificates openjdk-17-jdk-headless \
&& sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*
wget --progress=dot:giga "https://www.apache.org/dyn/closer.lua/spark/spark-3.3.0/spark-3.3.0-bin-hadoop2.tgz?action=download" \
-O - | tar -xzC /tmp; archive=$(basename "spark-3.3.0/spark-3.3.0-bin-hadoop2.tgz") \
bash -c "sudo mv -v /tmp/\${archive/%.tgz/} /spark"
export SPARK_HOME=/spark
export PYTHONPATH=/spark/python/lib/py4j-0.10.9.5-src.zip:/spark/python
export PATH=$PATH:$SPARK_HOME/bin
```
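After setting these environment variables, you can do a quick sanity check from Python to confirm that PySpark picks up the standalone installation; the `local[2]` master and the app name below are arbitrary illustrative choices:
```python
from pyspark.sql import SparkSession

# Start a throwaway local session and print the Spark version
# (expected to be 3.3.0 for the installation above).
spark = SparkSession.builder.master("local[2]").appName("flaml-spark-check").getOrCreate()
print(spark.version)
spark.stop()
```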
* nni
```bash
pip install flaml[nni]


@ -382,7 +382,11 @@ and have ``split`` and ``get_n_splits`` methods with the same signatures. To di
When you have parallel resources, you can either spend them on training and keep the model search sequential, or perform parallel search. Following scikit-learn, the parameter `n_jobs` specifies how many CPU cores to use for each training job. The number of parallel trials is specified via the parameter `n_concurrent_trials`. By default, `n_jobs=-1, n_concurrent_trials=1`. That is, all the CPU cores (in a single compute node) are used for training a single model and the search is sequential. When you have more resources than a single training job needs, you can consider increasing `n_concurrent_trials`.
To do parallel tuning, install the `ray` and `blendsearch` options:
FLAML now supports two backends for parallel tuning, namely `Ray` and `Spark`. You can use either of them, but not both, for a single tuning job.
#### Parallel tuning with Ray
To do parallel tuning with Ray, install the `ray` and `blendsearch` options:
```bash
pip install flaml[ray,blendsearch]
```
@ -397,6 +401,23 @@ automl.fit(X_train, y_train, n_jobs=4, n_concurrent_trials=4)
```
flaml will perform 4 trials in parallel, each consuming 4 CPU cores. The parallel tuning uses the [BlendSearch](Tune-User-Defined-Function#blendsearch-economical-hyperparameter-optimization-with-blended-search-strategy) algorithm.
#### Parallel tuning with Spark
To do parallel tuning with Spark, install the `spark` and `blendsearch` options:
> *Spark support was added in v1.1.0*
```bash
pip install "flaml[spark,blendsearch]>=1.1.0"
```
For more details about installing Spark, please refer to [Installation](../Installation#Distributed-tuning).
An example of using Spark for parallel tuning is:
```python
automl.fit(X_train, y_train, n_concurrent_trials=4, use_spark=True)
```
For Spark clusters, by default, we will launch one trial per executor. However, sometimes you may want to launch more trials than the number of executors (e.g., in local mode). In this case, you can set the environment variable `FLAML_MAX_CONCURRENT` to override the detected `num_executors`. The final number of concurrent trials will be the minimum of `n_concurrent_trials` and `num_executors`. Also, GPU training is not supported yet when `use_spark` is `True`.
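For example, a minimal sketch of raising the concurrency in local mode (assuming `automl`, `X_train`, and `y_train` are defined as in the examples above):
```python
import os

# Override the detected number of executors; useful in local mode,
# where only one executor would otherwise be detected.
os.environ["FLAML_MAX_CONCURRENT"] = "4"

# The effective concurrency is min(n_concurrent_trials, FLAML_MAX_CONCURRENT).
automl.fit(X_train, y_train, n_concurrent_trials=4, use_spark=True)
```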
#### **Guidelines on parallel vs sequential tuning**
**(1) Considerations on wall-clock time.**


@ -290,10 +290,13 @@ The key difference between these two types of constraints is that the calculatio
Related arguments:
- `use_ray`: A boolean of whether to use ray as the backend.
- `use_spark`: A boolean of whether to use spark as the backend.
- `resources_per_trial`: A dictionary of the hardware resources to allocate per trial, e.g., `{'cpu': 1}`. Only valid when using the Ray backend.
You can perform parallel tuning by specifying `use_ray=True` (requiring flaml[ray] option installed). You can also limit the amount of resources allocated per trial by specifying `resources_per_trial`, e.g., `resources_per_trial={'cpu': 2}`.
You can perform parallel tuning by specifying `use_ray=True` (requires the flaml[ray] option) or `use_spark=True`
(requires the flaml[spark] option). You can also limit the amount of resources allocated per trial by specifying `resources_per_trial`,
e.g., `resources_per_trial={'cpu': 2}` when `use_ray=True`.
```python
# require: pip install flaml[ray]
@ -311,6 +314,21 @@ print(analysis.best_trial.last_result) # the best trial's result
print(analysis.best_config) # the best config
```
```python
# require: pip install flaml[spark]
analysis = tune.run(
evaluate_config, # the function to evaluate a config
config=config_search_space, # the search space defined
metric="score",
mode="min", # the optimization mode, "min" or "max"
num_samples=-1, # the maximal number of configs to try, -1 means infinite
time_budget_s=10, # the time budget in seconds
use_spark=True,
)
print(analysis.best_trial.last_result) # the best trial's result
print(analysis.best_config) # the best config
```
**A heads-up about computation overhead.** When parallel tuning is used, there is a certain amount of computation overhead in each trial. If each trial's original cost is much smaller than the overhead, parallel tuning can underperform sequential tuning. Sequential tuning is recommended when compute resources are limited and each trial can consume all of them.