rename training_function (#327)

* rename training_function

* add docstr

* update docstr

* update docstr and comments

Co-authored-by: Qingyun Wu <qxw5138@psu.edu>
Qingyun Wu 2021-12-06 17:03:43 -05:00 committed by GitHub
parent 3111084c07
commit dd60dbc5eb
2 changed files with 17 additions and 10 deletions


@@ -109,7 +109,7 @@ def report(_metric=None, **kwargs):
def run(
-training_function,
+evaluation_function,
config: Optional[dict] = None,
low_cost_partial_config: Optional[dict] = None,
cat_hp_cost: Optional[dict] = None,
@@ -122,7 +122,7 @@ def run(
min_resource: Optional[float] = None,
max_resource: Optional[float] = None,
reduction_factor: Optional[float] = None,
-scheduler: Optional = None,
+scheduler=None,
search_alg=None,
verbose: Optional[int] = 2,
local_dir: Optional[str] = None,
@@ -162,7 +162,12 @@ def run(
print(analysis.trials[-1].last_result)
Args:
-training_function: A user-defined training function.
+evaluation_function: A user-defined evaluation function.
+It takes a configuration as input and outputs an evaluation
+result (can be a numerical value or a dictionary of string
+and numerical value pairs) for the input configuration.
+For machine learning tasks, it usually involves training and
+scoring a machine learning model, e.g., through validation loss.
config: A dictionary to specify the search space.
low_cost_partial_config: A dictionary from a subset of
controlled dimensions to the initial low-cost values.
@@ -195,7 +200,9 @@ def run(
needing to re-compute the trial. Must be the same length as
points_to_evaluate.
e.g.,
+.. code-block:: python
points_to_evaluate = [
{"b": .99, "cost_related": {"a": 3}},
{"b": .99, "cost_related": {"a": 2}},
@@ -204,7 +211,7 @@ def run(
means that you know the reward for the two configs in
points_to_evaluate are 3.0 and 1.0 respectively and want to
-inform run()
+inform run().
resource_attr: A string to specify the resource dimension used by
the scheduler via "scheduler".
@@ -217,7 +224,7 @@ def run(
in this case, when resource_attr is provided, the 'flaml' scheduler will be
used; otherwise no scheduler will be used. When set to 'flaml', an
authentic scheduler implemented in FLAML will be used. It does not
-require users to report intermediate results in training_function.
+require users to report intermediate results in evaluation_function.
Find more details about this scheduler in this paper:
https://arxiv.org/pdf/1911.04706.pdf.
When set 'asha', the input for arguments "resource_attr",
@@ -225,9 +232,9 @@ def run(
to ASHA's "time_attr", "max_t", "grace_period" and "reduction_factor"
respectively. You can also provide a self-defined scheduler instance
of the TrialScheduler class. When 'asha' or self-defined scheduler is
-used, you usually need to report intermediate results in the training
+used, you usually need to report intermediate results in the evaluation
function. Please find examples using different types of schedulers
-and how to set up the corresponding training functions in
+and how to set up the corresponding evaluation functions in
test/tune/test_scheduler.py. TODO: point to notebook examples.
search_alg: An instance of BlendSearch as the search algorithm
to be used. The same instance can be used for iterative tuning.
@@ -382,7 +389,7 @@ def run(
)
_use_ray = True
return tune.run(
-training_function,
+evaluation_function,
metric=metric,
mode=mode,
search_alg=search_alg,
@@ -423,7 +430,7 @@ def run(
num_trials += 1
if verbose:
logger.info(f"trial {num_trials} config: {trial_to_run.config}")
-result = training_function(trial_to_run.config)
+result = evaluation_function(trial_to_run.config)
if result is not None:
if isinstance(result, dict):
report(**result)
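
For context, a minimal sketch (not part of this commit) of how the renamed argument is used: the evaluation function takes a config dict and returns a numerical value or a dict of metrics, as the new docstring describes. The search space and toy objective below are illustrative, not from the FLAML repo.

from flaml import tune

def evaluation_function(config):
    # Toy objective: a quadratic in "x". A real task would train and
    # score a model here, e.g., return a validation loss.
    return {"score": (config["x"] - 2) ** 2}

analysis = tune.run(
    evaluation_function,  # formerly the `training_function` argument
    config={"x": tune.uniform(-10, 10)},  # illustrative search space
    metric="score",
    mode="min",
    num_samples=20,
)
print(analysis.best_config)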


@@ -893,7 +893,7 @@
" metric=HP_METRIC,\n",
" mode=MODE,\n",
" low_cost_partial_config={\"num_train_epochs\": 1}),\n",
" # uncomment the following if scheduler = 'auto',\n",
" # uncomment the following if scheduler = 'asha',\n",
" # max_resource=max_num_epoch, min_resource=1,\n",
" resources_per_trial={\"gpu\": num_gpus, \"cpu\": num_cpus},\n",
" local_dir='logs/',\n",