rename training_function (#327)

* rename training_function

* add docstr

* update docstr

* update docstr and comments

Co-authored-by: Qingyun Wu <qxw5138@psu.edu>
Qingyun Wu 2021-12-06 17:03:43 -05:00 committed by GitHub
parent 3111084c07
commit dd60dbc5eb
2 changed files with 17 additions and 10 deletions


@@ -109,7 +109,7 @@ def report(_metric=None, **kwargs):
 
 
 def run(
-    training_function,
+    evaluation_function,
     config: Optional[dict] = None,
     low_cost_partial_config: Optional[dict] = None,
     cat_hp_cost: Optional[dict] = None,
@@ -122,7 +122,7 @@ def run(
     min_resource: Optional[float] = None,
     max_resource: Optional[float] = None,
     reduction_factor: Optional[float] = None,
-    scheduler: Optional = None,
+    scheduler=None,
     search_alg=None,
     verbose: Optional[int] = 2,
     local_dir: Optional[str] = None,
@@ -162,7 +162,12 @@ def run(
         print(analysis.trials[-1].last_result)
 
     Args:
-        training_function: A user-defined training function.
+        evaluation_function: A user-defined evaluation function.
+            It takes a configuration as input, outputs an evaluation
+            result (can be a numerical value or a dictionary of string
+            and numerical value pairs) for the input configuration.
+            For machine learning tasks, it usually involves training and
+            scoring a machine learning model, e.g., through validation loss.
         config: A dictionary to specify the search space.
         low_cost_partial_config: A dictionary from a subset of
             controlled dimensions to the initial low-cost values.
@@ -195,7 +200,9 @@ def run(
             needing to re-compute the trial. Must be the same length as
             points_to_evaluate.
             e.g.,
+
     .. code-block:: python
+
             points_to_evaluate = [
                 {"b": .99, "cost_related": {"a": 3}},
                 {"b": .99, "cost_related": {"a": 2}},
@@ -204,7 +211,7 @@ def run(
             means that you know the reward for the two configs in
             points_to_evaluate are 3.0 and 1.0 respectively and want to
-            inform run()
+            inform run().
 
         resource_attr: A string to specify the resource dimension used by
             the scheduler via "scheduler".
@@ -217,7 +224,7 @@ def run(
            in this case when resource_attr is provided, the 'flaml' scheduler will be
            used, otherwise no scheduler will be used. When set 'flaml', an
            authentic scheduler implemented in FLAML will be used. It does not
-           require users to report intermediate results in training_function.
+           require users to report intermediate results in evaluation_function.
            Find more details about this scheduler in this paper
            https://arxiv.org/pdf/1911.04706.pdf).
            When set 'asha', the input for arguments "resource_attr",
@@ -225,9 +232,9 @@ def run(
            to ASHA's "time_attr", "max_t", "grace_period" and "reduction_factor"
            respectively. You can also provide a self-defined scheduler instance
            of the TrialScheduler class. When 'asha' or self-defined scheduler is
-           used, you usually need to report intermediate results in the training
+           used, you usually need to report intermediate results in the evaluation
            function. Please find examples using different types of schedulers
-           and how to set up the corresponding training functions in
+           and how to set up the corresponding evaluation functions in
            test/tune/test_scheduler.py. TODO: point to notebook examples.
         search_alg: An instance of BlendSearch as the search algorithm
             to be used. The same instance can be used for iterative tuning.
@@ -382,7 +389,7 @@ def run(
         )
         _use_ray = True
         return tune.run(
-            training_function,
+            evaluation_function,
             metric=metric,
             mode=mode,
             search_alg=search_alg,
@@ -423,7 +430,7 @@ def run(
             num_trials += 1
             if verbose:
                 logger.info(f"trial {num_trials} config: {trial_to_run.config}")
-            result = training_function(trial_to_run.config)
+            result = evaluation_function(trial_to_run.config)
             if result is not None:
                 if isinstance(result, dict):
                     report(**result)
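The rename does not change the call pattern: the first positional argument to `tune.run` is simply now named `evaluation_function`. A minimal sketch of the renamed API, with a toy objective (the function body, search space, and sample budget below are illustrative, not part of this commit):

```python
from flaml import tune


def evaluation_function(config):
    # A user-defined evaluation function per the new docstring:
    # takes a configuration as input and returns either a numerical
    # value or a dictionary of metric-name/value pairs.
    score = (config["x"] - 8) ** 2
    return {"score": score}


analysis = tune.run(
    evaluation_function,  # formerly the `training_function` argument
    config={"x": tune.randint(lower=1, upper=100)},
    metric="score",
    mode="min",
    num_samples=10,
)
print(analysis.trials[-1].last_result)
```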


@@ -893,7 +893,7 @@
     "    metric=HP_METRIC,\n",
     "    mode=MODE,\n",
     "    low_cost_partial_config={\"num_train_epochs\": 1}),\n",
-    "    # uncomment the following if scheduler = 'auto',\n",
+    "    # uncomment the following if scheduler = 'asha',\n",
     "    # max_resource=max_num_epoch, min_resource=1,\n",
     "    resources_per_trial={\"gpu\": num_gpus, \"cpu\": num_cpus},\n",
     "    local_dir='logs/',\n",