tune api for schedulers (#322)

* revise api and tests

* rename prune_attr

* update finetune notebook

* add scheduler test and notebook

* update tune api for scheduler

* remove scheduler notebook

* Update flaml/tune/tune.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* docstr

* fix imports

* clear notebook output

* fix ray import

* Update flaml/tune/tune.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* improve docstr

* Update flaml/searcher/blendsearch.py

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* remove redundant import

Co-authored-by: Qingyun Wu <qxw5138@psu.edu>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
Qingyun Wu 2021-12-04 21:52:20 -05:00, committed by GitHub
parent 7d269435ae
commit 17b17d084f
15 changed files with 499 additions and 970 deletions
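
In brief: prune_attr is renamed to resource_attr across tune.run, BlendSearch, FLOW2 and AutoML, and tune.run's boolean report_intermediate_result is replaced by a scheduler argument. A minimal before/after sketch of the user-facing change (argument values are illustrative, not taken from the diff):

    from flaml import tune

    # flaml <= 0.8.2
    tune.run(
        training_function, config=space, metric="loss", mode="min",
        prune_attr="sample_size", min_resource=100, max_resource=10000,
        report_intermediate_result=True,
    )

    # flaml >= 0.9.0 (this commit)
    tune.run(
        training_function, config=space, metric="loss", mode="min",
        resource_attr="sample_size", min_resource=100, max_resource=10000,
        scheduler="asha",
    )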


@@ -1546,11 +1546,12 @@ class AutoML(BaseEstimator):
         return points

     @property
-    def prune_attr(self) -> Optional[str]:
-        """Attribute for pruning
+    def resource_attr(self) -> Optional[str]:
+        """Attribute of the resource dimension.

         Returns:
-            A string for the sample size attribute or None
+            A string for the sample size attribute
+            (the resource attribute in AutoML) or None.
         """
         return "FLAML_sample_size" if self._sample else None
@@ -2178,7 +2179,7 @@ class AutoML(BaseEstimator):
             low_cost_partial_config=self.low_cost_partial_config,
             points_to_evaluate=self.points_to_evaluate,
             cat_hp_cost=self.cat_hp_cost,
-            prune_attr=self.prune_attr,
+            resource_attr=self.resource_attr,
             min_resource=self.min_resource,
             max_resource=self.max_resource,
             config_constraints=[
@@ -2326,11 +2327,11 @@ class AutoML(BaseEstimator):
         )
         search_space = search_state.search_space
         if self._sample:
-            prune_attr = "FLAML_sample_size"
+            resource_attr = "FLAML_sample_size"
             min_resource = self._min_sample_size
             max_resource = self._state.data_size[0]
         else:
-            prune_attr = min_resource = max_resource = None
+            resource_attr = min_resource = max_resource = None
         learner_class = self._state.learner_classes.get(estimator)
         if "grid" == self._hpo_method:  # for synthetic exp only
             points_to_evaluate = []
@@ -2362,7 +2363,7 @@ class AutoML(BaseEstimator):
             points_to_evaluate=points_to_evaluate,
             low_cost_partial_config=low_cost_partial_config,
             cat_hp_cost=search_state.cat_hp_cost,
-            prune_attr=prune_attr,
+            resource_attr=resource_attr,
             min_resource=min_resource,
             max_resource=max_resource,
             config_constraints=[

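The renamed AutoML.resource_attr property keeps the old prune_attr semantics. A quick sketch of inspecting it after a fit (data and settings are illustrative; the same three attributes are printed in the test files below):

    from flaml import AutoML

    automl = AutoML()
    automl.fit(X_train, y_train, task="classification", time_budget=5)
    print(automl.resource_attr)  # "FLAML_sample_size" when sampling is enabled, else None
    print(automl.min_resource, automl.max_resource)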

@@ -45,7 +45,7 @@ class BlendSearch(Searcher):
         evaluated_rewards: Optional[List] = None,
         time_budget_s: Union[int, float] = None,
         num_samples: Optional[int] = None,
-        prune_attr: Optional[str] = None,
+        resource_attr: Optional[str] = None,
         min_resource: Optional[float] = None,
         max_resource: Optional[float] = None,
         reduction_factor: Optional[float] = None,
@@ -91,17 +91,10 @@ class BlendSearch(Searcher):
                 points_to_evaluate.
             time_budget_s: int or float | Time budget in seconds.
             num_samples: int | The number of configs to try.
-            prune_attr: A string of the attribute used for pruning.
-                Not necessarily in space.
-                When prune_attr is in space, it is a hyperparameter, e.g.,
-                'n_iters', and the best value is unknown.
-                When prune_attr is not in space, it is a resource dimension,
-                e.g., 'sample_size', and the peak performance is assumed
-                to be at the max_resource.
-            min_resource: A float of the minimal resource to use for the
-                prune_attr; only valid if prune_attr is not in space.
-            max_resource: A float of the maximal resource to use for the
-                prune_attr; only valid if prune_attr is not in space.
+            resource_attr: A string to specify the resource dimension; the best
+                performance is assumed to be at the max_resource.
+            min_resource: A float of the minimal resource to use for the resource_attr.
+            max_resource: A float of the maximal resource to use for the resource_attr.
             reduction_factor: A float of the reduction factor used for
                 incremental pruning.
             global_search_alg: A Searcher instance as the global search
@@ -160,7 +153,7 @@ class BlendSearch(Searcher):
             metric,
             mode,
             space,
-            prune_attr,
+            resource_attr,
             min_resource,
             max_resource,
             reduction_factor,
@@ -409,7 +402,7 @@ class BlendSearch(Searcher):
         if (objective - self._metric_target) * self._ls.metric_op < 0:
             self._metric_target = objective
             if self._ls.resource:
-                self._best_resource = config[self._ls.prune_attr]
+                self._best_resource = config[self._ls.resource_attr]
         if thread_id:
             if not self._metric_constraint_satisfied:
                 # no point has been found to satisfy metric constraint
@@ -637,7 +630,7 @@ class BlendSearch(Searcher):
             # return None
         config = self._search_thread_pool[choice].suggest(trial_id)
         if not choice and config is not None and self._ls.resource:
-            config[self._ls.prune_attr] = self.best_resource
+            config[self._ls.resource_attr] = self.best_resource
         elif choice and config is None:
             # local search thread finishes
             if self._search_thread_pool[choice].converged:
@@ -975,7 +968,7 @@ class BlendSearchTuner(BlendSearch, NNITuner):
             self._ls.metric,
             self._mode,
             config,
-            self._ls.prune_attr,
+            self._ls.resource_attr,
             self._ls.min_resource,
             self._ls.max_resource,
             self._ls.resource_multiple_factor,

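For reference, a sketch of constructing the searcher directly with the renamed argument (space and budgets are illustrative; BlendSearch is exported at the package top level):

    from flaml import BlendSearch, tune

    search_alg = BlendSearch(
        metric="val_loss",
        mode="min",
        space={"lr": tune.loguniform(1e-4, 1e-1)},
        resource_attr="sample_size",  # renamed from prune_attr
        min_resource=100,
        max_resource=10000,
        reduction_factor=4,
    )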

@@ -39,7 +39,7 @@ class FLOW2(Searcher):
         metric: Optional[str] = None,
         mode: Optional[str] = None,
         space: Optional[dict] = None,
-        prune_attr: Optional[str] = None,
+        resource_attr: Optional[str] = None,
         min_resource: Optional[float] = None,
         max_resource: Optional[float] = None,
         resource_multiple_factor: Optional[float] = 4,
@@ -67,17 +67,10 @@ class FLOW2(Searcher):
             i.e., the relative cost of the
             three choices of 'tree_method' is 1, 1 and 2 respectively.
         space: A dictionary to specify the search space.
-        prune_attr: A string of the attribute used for pruning.
-            Not necessarily in space.
-            When prune_attr is in space, it is a hyperparameter, e.g.,
-            'n_iters', and the best value is unknown.
-            When prune_attr is not in space, it is a resource dimension,
-            e.g., 'sample_size', and the peak performance is assumed
-            to be at the max_resource.
-        min_resource: A float of the minimal resource to use for the
-            prune_attr; only valid if prune_attr is not in space.
-        max_resource: A float of the maximal resource to use for the
-            prune_attr; only valid if prune_attr is not in space.
+        resource_attr: A string to specify the resource dimension; the best
+            performance is assumed to be at the max_resource.
+        min_resource: A float of the minimal resource to use for the resource_attr.
+        max_resource: A float of the maximal resource to use for the resource_attr.
         resource_multiple_factor: A float of the multiplicative factor
             used for increasing resource.
         cost_attr: A string of the attribute used for cost.
@@ -100,7 +93,7 @@ class FLOW2(Searcher):
         self.seed = seed
         self.init_config = init_config
         self.best_config = flatten_dict(init_config)
-        self.prune_attr = prune_attr
+        self.resource_attr = resource_attr
         self.min_resource = min_resource
         self.resource_multiple_factor = resource_multiple_factor or 4
         self.cost_attr = cost_attr
@@ -148,11 +141,15 @@ class FLOW2(Searcher):
         if not hier:
             self._space_keys = sorted(self._tunable_keys)
         self.hierarchical = hier
-        if self.prune_attr and self.prune_attr not in self._space and self.max_resource:
+        if (
+            self.resource_attr
+            and self.resource_attr not in self._space
+            and self.max_resource
+        ):
             self.min_resource = self.min_resource or self._min_resource()
             self._resource = self._round(self.min_resource)
             if not hier:
-                self._space_keys.append(self.prune_attr)
+                self._space_keys.append(self.resource_attr)
         else:
             self._resource = None
         self.incumbent = {}
@@ -252,7 +249,7 @@ class FLOW2(Searcher):
         if partial_config == self.init_config:
             self._reset_times += 1
         if self._resource:
-            config[self.prune_attr] = self.min_resource
+            config[self.resource_attr] = self.min_resource
         return config, space

     def create(
@@ -264,7 +261,7 @@ class FLOW2(Searcher):
             self.metric,
             self.mode,
             space,
-            self.prune_attr,
+            self.resource_attr,
             self.min_resource,
             self.max_resource,
             self.resource_multiple_factor,
@@ -328,7 +325,7 @@ class FLOW2(Searcher):
             self.incumbent = self.normalize(self.best_config)
             self.cost_incumbent = result.get(self.cost_attr)
             if self._resource:
-                self._resource = self.best_config[self.prune_attr]
+                self._resource = self.best_config[self.resource_attr]
             self._num_complete4incumbent = 0
             self._cost_complete4incumbent = 0
             self._num_proposedby_incumbent = 0
@@ -377,7 +374,7 @@ class FLOW2(Searcher):
             if self.best_config != config:
                 self.best_config = config
                 if self._resource:
-                    self._resource = config[self.prune_attr]
+                    self._resource = config[self.resource_attr]
                 self.incumbent = self.normalize(self.best_config)
                 self.cost_incumbent = result.get(self.cost_attr)
                 self._cost_complete4incumbent = 0
@@ -495,18 +492,18 @@ class FLOW2(Searcher):
             self._resource = self._round(self._resource * self.resource_multiple_factor)
             self.cost_incumbent *= self._resource / old_resource
             config = self.best_config.copy()
-            config[self.prune_attr] = self._resource
+            config[self.resource_attr] = self._resource
             self._direction_tried = None
             self._configs[trial_id] = (config, self.step)
             return unflatten_dict(config)

     def _project(self, config):
-        """project normalized config in the feasible region and set prune_attr"""
+        """project normalized config in the feasible region and set resource_attr"""
         for key in self._bounded_keys:
             value = config[key]
             config[key] = max(0, min(1, value))
         if self._resource:
-            config[self.prune_attr] = self._resource
+            config[self.resource_attr] = self._resource

     @property
     def can_suggest(self) -> bool:
@@ -525,7 +522,7 @@ class FLOW2(Searcher):
         keys = sorted(config.keys()) if self.hierarchical else self._space_keys
         for key in keys:
             value = config[key]
-            if key == self.prune_attr:
+            if key == self.resource_attr:
                 value_list.append(value)
             else:
                 # key must be in space
@@ -556,7 +553,7 @@ class FLOW2(Searcher):
         """whether the incumbent can reach the incumbent of other."""
         config1, config2 = self.best_config, other.best_config
         incumbent1, incumbent2 = self.incumbent, other.incumbent
-        if self._resource and config1[self.prune_attr] > config2[self.prune_attr]:
+        if self._resource and config1[self.resource_attr] > config2[self.resource_attr]:
             # resource will not decrease
             return False
         for key in self._unordered_cat_hp:

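FLOW2 is normally created by BlendSearch rather than by users, but for illustration, a direct construction with the renamed argument might look like this (a sketch; init_config, space and metric are illustrative, and the init_config first argument is assumed from FLOW2's signature):

    from flaml.searcher.flow2 import FLOW2
    from flaml import tune

    local_search = FLOW2(
        init_config={"lr": 0.01},
        metric="val_loss",
        mode="min",
        space={"lr": tune.loguniform(1e-4, 1e-1)},
        resource_attr="sample_size",  # renamed from prune_attr
        min_resource=100,
        max_resource=10000,
        resource_multiple_factor=4,
    )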

@@ -247,7 +247,7 @@ def normalize(
     config_norm = {}
     for key, value in config.items():
         domain = space.get(key)
-        if domain is None:  # e.g., prune_attr
+        if domain is None:  # e.g., resource_attr
             config_norm[key] = value
             continue
         if not callable(getattr(domain, "get_sampler", None)):
@@ -405,7 +405,7 @@ def denormalize(
                 # Handle int (4.6 -> 5)
                 if isinstance(domain, sample.Integer):
                     config_denorm[key] = int(round(config_denorm[key]))
-            else:  # prune_attr
+            else:  # resource_attr
                 config_denorm[key] = value
     return config_denorm


@@ -17,6 +17,7 @@ try:
 except (ImportError, AssertionError):
     ray_import = False
     from .analysis import ExperimentAnalysis as EA
+from .result import DEFAULT_METRIC

 import logging
@@ -117,11 +118,11 @@ def run(
     time_budget_s: Union[int, float] = None,
     points_to_evaluate: Optional[List[dict]] = None,
     evaluated_rewards: Optional[List] = None,
-    prune_attr: Optional[str] = None,
+    resource_attr: Optional[str] = None,
     min_resource: Optional[float] = None,
     max_resource: Optional[float] = None,
     reduction_factor: Optional[float] = None,
-    report_intermediate_result: Optional[bool] = False,
+    scheduler: Optional = None,
     search_alg=None,
     verbose: Optional[int] = 2,
     local_dir: Optional[str] = None,
@@ -205,21 +206,29 @@ def run(
             points_to_evaluate are 3.0 and 1.0 respectively and want to
             inform run()
-        prune_attr: A string of the attribute used for pruning.
-            Not necessarily in space.
-            When prune_attr is in space, it is a hyperparameter, e.g.,
-            'n_iters', and the best value is unknown.
-            When prune_attr is not in space, it is a resource dimension,
-            e.g., 'sample_size', and the peak performance is assumed
-            to be at the max_resource.
-        min_resource: A float of the minimal resource to use for the
-            prune_attr; only valid if prune_attr is not in space.
-        max_resource: A float of the maximal resource to use for the
-            prune_attr; only valid if prune_attr is not in space.
+        resource_attr: A string to specify the resource dimension used by
+            the scheduler set via "scheduler".
+        min_resource: A float of the minimal resource to use for the resource_attr.
+        max_resource: A float of the maximal resource to use for the resource_attr.
         reduction_factor: A float of the reduction factor used for incremental
             pruning.
-        report_intermediate_result: A boolean of whether intermediate results
-            are reported. If so, early stopping and pruning can be used.
+        scheduler: A scheduler for executing the experiment. Can be None, 'flaml',
+            'asha' or a custom instance of the TrialScheduler class. Default is None:
+            in this case, when resource_attr is provided, the 'flaml' scheduler will
+            be used; otherwise no scheduler will be used. When set to 'flaml', a
+            scheduler native to FLAML's search algorithms will be used. It does not
+            require users to report intermediate results in training_function.
+            Find more details about this scheduler in this paper:
+            https://arxiv.org/pdf/1911.04706.pdf.
+            When set to 'asha', the input for the arguments "resource_attr",
+            "min_resource", "max_resource" and "reduction_factor" will be passed
+            to ASHA's "time_attr", "grace_period", "max_t" and "reduction_factor"
+            respectively. You can also provide a custom scheduler instance
+            of the TrialScheduler class. When 'asha' or a custom scheduler is
+            used, you usually need to report intermediate results in the training
+            function. Please find examples using different types of schedulers
+            and how to set up the corresponding training functions in
+            test/tune/test_scheduler.py. TODO: point to notebook examples.
         search_alg: An instance of BlendSearch as the search algorithm
             to be used. The same instance can be used for iterative tuning.
             e.g.,
@@ -295,6 +304,20 @@ def run(
         from ..searcher.blendsearch import BlendSearch

     if search_alg is None:
+        flaml_scheduler_resource_attr = (
+            flaml_scheduler_min_resource
+        ) = flaml_scheduler_max_resource = flaml_scheduler_reduction_factor = None
+        if scheduler in (None, "flaml"):
+            # when scheduler is set to 'flaml', we will use a scheduler that is
+            # native to the search algorithms in flaml. After setting up
+            # the search algorithm accordingly, we need to set scheduler to
+            # None so that it is not passed to the trial runner later.
+            flaml_scheduler_resource_attr = resource_attr
+            flaml_scheduler_min_resource = min_resource
+            flaml_scheduler_max_resource = max_resource
+            flaml_scheduler_reduction_factor = reduction_factor
+            scheduler = None
         search_alg = BlendSearch(
-            metric=metric,
+            metric=metric or DEFAULT_METRIC,
             mode=mode,
@@ -305,10 +328,10 @@ def run(
             cat_hp_cost=cat_hp_cost,
             time_budget_s=time_budget_s,
             num_samples=num_samples,
-            prune_attr=prune_attr,
-            min_resource=min_resource,
-            max_resource=max_resource,
-            reduction_factor=reduction_factor,
+            resource_attr=flaml_scheduler_resource_attr,
+            min_resource=flaml_scheduler_min_resource,
+            max_resource=flaml_scheduler_max_resource,
+            reduction_factor=flaml_scheduler_reduction_factor,
             config_constraints=config_constraints,
             metric_constraints=metric_constraints,
         )
@@ -334,12 +357,11 @@ def run(
             searcher.set_search_properties(metric, mode, config, setting)
         else:
             searcher.set_search_properties(metric, mode, config)
-    scheduler = None
-    if report_intermediate_result:
+    if scheduler == "asha":
         params = {}
-        # scheduler resource_dimension=prune_attr
-        if prune_attr:
-            params["time_attr"] = prune_attr
+        # scheduler resource_dimension=resource_attr
+        if resource_attr:
+            params["time_attr"] = resource_attr
         if max_resource:
             params["max_t"] = max_resource
         if min_resource:

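Putting the new arguments together, a sketch of the 'asha' path through tune.run (train_and_eval is a hypothetical helper; with scheduler="asha" the training function reports intermediate results via tune.report):

    from flaml import tune

    def training_function(config):
        for epoch in range(10):
            val_loss = train_and_eval(config["lr"], epoch)  # hypothetical helper
            tune.report(num_epochs=epoch + 1, val_loss=val_loss)

    analysis = tune.run(
        training_function,
        config={"lr": tune.loguniform(1e-4, 1e-1)},
        metric="val_loss",
        mode="min",
        resource_attr="num_epochs",  # renamed from prune_attr
        min_resource=1,
        max_resource=10,
        scheduler="asha",  # None, "flaml", "asha", or a TrialScheduler instance
        time_budget_s=60,
        num_samples=100,
    )
    print(analysis.best_config)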

@@ -1 +1 @@
-__version__ = "0.8.2"
+__version__ = "0.9.0"

File diff suppressed because one or more lines are too long


@@ -286,7 +286,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "time_budget_s = 600 # time budget in seconds\n",
+    "time_budget_s = 3600 # time budget in seconds\n",
     "gpus_per_trial = 0.5 # number of gpus for each trial; 0.5 means two training jobs can share one gpu\n",
     "num_samples = 500 # maximal number of trials\n",
     "np.random.seed(7654321)"
@@ -315,7 +315,7 @@
     " low_cost_partial_config={\"num_epochs\": 1},\n",
     " max_resource=max_num_epoch,\n",
     " min_resource=1,\n",
-    " report_intermediate_result=True, # only set to True when intermediate results are reported by tune.report\n",
+    " scheduler=\"asha\", # need to use tune.report to report intermediate results in training_function\n",
     " resources_per_trial={\"cpu\": 1, \"gpu\": gpus_per_trial},\n",
     " local_dir='logs/',\n",
     " num_samples=num_samples,\n",
@@ -325,24 +325,9 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 13,
+    "execution_count": null,
     "metadata": {},
-    "outputs": [
-     {
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-       "#trials=44\n",
-       "time=1193.913584947586\n",
-       "Best trial config: {'l1': 8, 'l2': 8, 'lr': 0.0008818671030627281, 'num_epochs': 55.9513429004283, 'batch_size': 3}\n",
-       "Best trial final validation loss: 1.0694482081472874\n",
-       "Best trial final validation accuracy: 0.6389\n",
-       "Files already downloaded and verified\n",
-       "Files already downloaded and verified\n",
-       "Best trial test set accuracy: 0.6294\n"
-      ]
-     }
-    ],
+    "outputs": [],
     "source": [
     "print(f\"#trials={len(result.trials)}\")\n",
     "print(f\"time={time.time()-start_time}\")\n",
@@ -390,7 +375,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.12"
+   "version": "3.8.12"
   },
   "metadata": {
    "interpreter": {


@@ -71,7 +71,7 @@ def test_forecast_automl(budget=5):
     ) = get_output_from_log(filename=settings["log_file_name"], time_budget=budget)
     for config in config_history:
         print(config)
-    print(automl.prune_attr)
+    print(automl.resource_attr)
     print(automl.max_resource)
     print(automl.min_resource)

@@ -210,7 +210,7 @@ def test_multivariate_forecast_num(budget=5):
     ) = get_output_from_log(filename=settings["log_file_name"], time_budget=budget)
     for config in config_history:
         print(config)
-    print(automl.prune_attr)
+    print(automl.resource_attr)
     print(automl.max_resource)
     print(automl.min_resource)

@@ -341,7 +341,7 @@ def test_multivariate_forecast_cat(budget=5):
     ) = get_output_from_log(filename=settings["log_file_name"], time_budget=budget)
     for config in config_history:
         print(config)
-    print(automl.prune_attr)
+    print(automl.resource_attr)
     print(automl.max_resource)
     print(automl.min_resource)


@@ -64,7 +64,7 @@ def test_automl(budget=5, dataset_format="dataframe", hpo_method=None):
     ) = get_output_from_log(filename=settings["log_file_name"], time_budget=6)
     for config in config_history:
         print(config)
-    print(automl.prune_attr)
+    print(automl.resource_attr)
     print(automl.max_resource)
     print(automl.min_resource)


@@ -80,7 +80,7 @@ class TestLogging(unittest.TestCase):
                 low_cost_partial_config=low_cost_partial_config,
                 points_to_evaluate=automl.points_to_evaluate,
                 cat_hp_cost=automl.cat_hp_cost,
-                prune_attr=automl.prune_attr,
+                resource_attr=automl.resource_attr,
                 min_resource=automl.min_resource,
                 max_resource=automl.max_resource,
                 config_constraints=[


@@ -71,7 +71,7 @@ def test_simple(method=None):
         low_cost_partial_config=automl.low_cost_partial_config,
         points_to_evaluate=automl.points_to_evaluate,
         cat_hp_cost=automl.cat_hp_cost,
-        prune_attr=automl.prune_attr,
+        resource_attr=automl.resource_attr,
         min_resource=automl.min_resource,
         max_resource=automl.max_resource,
         time_budget_s=automl._state.time_budget,


@@ -239,7 +239,7 @@ def cifar10_main(
         low_cost_partial_config={"num_epochs": 1},
         max_resource=max_num_epochs,
         min_resource=1,
-        report_intermediate_result=True,
+        scheduler="asha",
         resources_per_trial={"cpu": 1, "gpu": gpus_per_trial},
         local_dir="logs/",
         num_samples=num_samples,

test/tune/test_scheduler.py (new file, 157 lines)

@@ -0,0 +1,157 @@
"""Require: pip install flaml[test,ray]
"""
from flaml.scheduler.trial_scheduler import TrialScheduler
import numpy as np
from flaml import tune
import time


def rand_vector_unit_sphere(dim):
    """Generate points uniformly distributed on the (dim-1)-sphere."""
    vec = np.random.normal(0, 1, dim)
    mag = np.linalg.norm(vec)
    return vec / mag


def simple_obj(config, resource=10000):
    config_value_vector = np.array([config["x"], config["y"], config["z"]])
    score_sequence = []
    for i in range(resource):
        a = rand_vector_unit_sphere(3)
        a[2] = abs(a[2])
        point_projection = np.dot(config_value_vector, a)
        score_sequence.append(point_projection)
    score_avg = np.mean(np.array(score_sequence))
    score_std = np.std(np.array(score_sequence))
    score_lb = score_avg - 1.96 * score_std / np.sqrt(resource)
    tune.report(samplesize=resource, sphere_projection=score_lb)


def obj_w_intermediate_report(resource, config):
    config_value_vector = np.array([config["x"], config["y"], config["z"]])
    score_sequence = []
    for i in range(resource):
        a = rand_vector_unit_sphere(3)
        a[2] = abs(a[2])
        point_projection = np.dot(config_value_vector, a)
        score_sequence.append(point_projection)
        if (i + 1) % 100 == 0:
            score_avg = np.mean(np.array(score_sequence))
            score_std = np.std(np.array(score_sequence))
            score_lb = score_avg - 1.96 * score_std / np.sqrt(i + 1)
            tune.report(samplesize=i + 1, sphere_projection=score_lb)


def obj_w_suggested_resource(resource_attr, config):
    resource = config[resource_attr]
    simple_obj(config, resource)


def test_scheduler(scheduler=None):
    from functools import partial

    resource_attr = "samplesize"
    max_resource = 10000
    # specify the objective functions
    if scheduler is None:
        evaluation_obj = simple_obj
    elif scheduler == "flaml":
        evaluation_obj = partial(obj_w_suggested_resource, resource_attr)
    elif scheduler == "asha" or isinstance(scheduler, TrialScheduler):
        evaluation_obj = partial(obj_w_intermediate_report, max_resource)
    else:
        try:
            from ray.tune.schedulers import TrialScheduler as RayTuneTrialScheduler
        except ImportError:
            print(
                "skip this condition, which may require TrialScheduler from ray tune, \
                as ray tune cannot be imported."
            )
            return
        if isinstance(scheduler, RayTuneTrialScheduler):
            evaluation_obj = partial(obj_w_intermediate_report, max_resource)
        else:
            raise ValueError
    analysis = tune.run(
        evaluation_obj,
        config={
            "x": tune.uniform(5, 20),
            "y": tune.uniform(0, 10),
            "z": tune.uniform(0, 10),
        },
        metric="sphere_projection",
        mode="max",
        verbose=1,
        resource_attr=resource_attr,
        scheduler=scheduler,
        max_resource=max_resource,
        min_resource=100,
        reduction_factor=2,
        time_budget_s=1,
        num_samples=500,
    )
    print("Best hyperparameters found were: ", analysis.best_config)
    # print(analysis.get_best_trial)
    return analysis.best_config


def test_no_scheduler():
    best_config = test_scheduler()
    print("No scheduler, test error:", abs(10 / 2 - best_config["z"] / 2))


def test_asha_scheduler():
    try:
        from ray.tune.schedulers import ASHAScheduler
    except ImportError:
        print("skip the test as ray tune cannot be imported.")
        return
    best_config = test_scheduler(scheduler="asha")
    print("Auto ASHA scheduler, test error:", abs(10 / 2 - best_config["z"] / 2))


def test_custom_scheduler():
    try:
        from ray.tune.schedulers import HyperBandScheduler
    except ImportError:
        print("skip the test as ray tune cannot be imported.")
        return
    my_scheduler = HyperBandScheduler(
        time_attr="samplesize", max_t=1000, reduction_factor=2
    )
    best_config = test_scheduler(scheduler=my_scheduler)
    print("Custom HyperBand scheduler, test error:", abs(10 / 2 - best_config["z"] / 2))


def test_custom_scheduler_default_time_attr():
    try:
        from ray.tune.schedulers import ASHAScheduler
    except ImportError:
        print("skip the test as ray tune cannot be imported.")
        return
    my_scheduler = ASHAScheduler(max_t=10)
    best_config = test_scheduler(scheduler=my_scheduler)
    print(
        "Custom ASHA scheduler (with ASHA default time attr), test error:",
        abs(10 / 2 - best_config["z"] / 2),
    )


def test_flaml_scheduler():
    best_config = test_scheduler(scheduler="flaml")
    print("FLAML scheduler, test error:", abs(10 / 2 - best_config["z"] / 2))


if __name__ == "__main__":
    test_no_scheduler()
    test_asha_scheduler()
    test_custom_scheduler()
    test_custom_scheduler_default_time_attr()
    test_flaml_scheduler()

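To exercise these scheduler paths locally (assuming the flaml[test,ray] extras from the file's docstring are installed):

    python test/tune/test_scheduler.py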

@@ -83,7 +83,7 @@ def _test_xgboost(method="BlendSearch"):
         mode="min",
         max_resource=max_iter,
         min_resource=1,
-        report_intermediate_result=True,
+        scheduler="asha",
         # You can add "gpu": 0.1 to allocate GPUs
         resources_per_trial={"cpu": 1},
         local_dir="logs/",