!3623 [r0.6][bug][auto_mixed_precision]fix amp doc and eval network build

Merge pull request !3623 from vlne-v1/amp_doc_r0.6
mindspore-ci-bot 2020-07-29 13:02:45 +08:00 committed by Gitee
commit 338a225410
2 changed files with 2 additions and 1 deletion

@@ -133,6 +133,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
         cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
             If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
         keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
+            `keep_batchnorm_fp32` takes effect only when `cast_model_type` is `float16`.
         loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
             scale the loss by LossScaleManager. If set, overwrite the level setting.
     """

@@ -174,7 +174,7 @@ class Model:
         else:
             if self._loss_fn is None:
                 raise ValueError("loss_fn can not be None.")
-            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
+            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O0", "O3"])
             self._eval_indexes = [0, 1, 2]
         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
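
For context, the third argument to nn.WithEvalCell is the add_cast_fp32 flag, which casts the network outputs (and labels) back to float32 before the loss is computed during evaluation; the fix derives it from the configured amp level instead of matching "O2" alone. A minimal sketch of the affected path, with placeholder network, loss, and metric choices:

    import mindspore.nn as nn
    from mindspore import Model

    # Placeholder network and loss; the point is only that amp_level feeds
    # the add_cast_fp32 decision shown in the diff above.
    net = nn.Dense(16, 10)
    loss = nn.SoftmaxCrossEntropyWithLogits()

    # With this amp level, the default eval network built by Model now casts
    # eval outputs back to float32 before the loss metric is computed.
    model = Model(net, loss_fn=loss, metrics={"loss"}, amp_level="O3")
    # model.eval(eval_dataset) would exercise the WithEvalCell shown above.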