forked from mindspore-Ecosystem/mindspore
fix doc and eval network build in amp
This commit is contained in:
parent
d9ca3f2e88
commit
dd26d85caf
|
@ -133,6 +133,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
|
|||
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
|
||||
If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
|
||||
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
|
||||
Only `cast_model_type` is `float16`, `keep_batchnorm_fp32` will take effect.
|
||||
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
|
||||
scale the loss by LossScaleManager. If set, overwrite the level setting.
|
||||
"""
|
||||
|
|
|
@ -174,7 +174,7 @@ class Model:
|
|||
else:
|
||||
if self._loss_fn is None:
|
||||
raise ValueError("loss_fn can not be None.")
|
||||
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
|
||||
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O0", "O3"])
|
||||
self._eval_indexes = [0, 1, 2]
|
||||
|
||||
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
|
||||
|
|
Loading…
Reference in New Issue