diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py
index 2c4cf69bf64..a47b16d0e02 100644
--- a/mindspore/train/amp.py
+++ b/mindspore/train/amp.py
@@ -127,7 +127,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
             - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
               using dynamic loss scale.
             - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
-            O2 is recommended on GPU, O3 is recommemded on Ascend.
+
+            O2 is recommended on GPU, O3 is recommended on Ascend.
 
         cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
             If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 2c08fa195b2..79bd6bc90ba 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -61,6 +61,7 @@ class Model:
             - O0: Do not change.
             - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
             - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
+            O2 is recommended on GPU, O3 is recommended on Ascend.
 
         loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
@@ -115,7 +116,7 @@ class Model:
         self._build_predict_network()
 
     def _process_amp_args(self, kwargs):
-        if self._amp_level == "O0":
+        if self._amp_level in ["O0", "O3"]:
             self._keep_bn_fp32 = False
         if 'keep_batchnorm_fp32' in kwargs:
             self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
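
Note on the behavioral part of this patch: `_process_amp_args` now treats "O3" like "O0" when choosing the batchnorm dtype default, so amp_level="O3" defaults keep_batchnorm_fp32 to False, matching the documented meaning of O3 ('keep_batchnorm_fp32=False'); an explicit keep_batchnorm_fp32 keyword still overrides the level default. Below is a minimal standalone sketch of that resolution logic, for illustration only. It is not the MindSpore source; the helper name resolve_keep_bn_fp32 is hypothetical, and it assumes the pre-existing default (for levels other than O0/O3) is to keep batchnorm in float32.

# Hypothetical helper mirroring the patched _process_amp_args control flow.
def resolve_keep_bn_fp32(amp_level, **kwargs):
    """Return whether batchnorm should stay in float32 for a given AMP level (sketch)."""
    keep_bn_fp32 = True  # assumed default: keep batchnorm in float32 (e.g. for O2)
    if amp_level in ["O0", "O3"]:
        # O0 does no casting and O3 documents 'keep_batchnorm_fp32=False',
        # so neither level keeps batchnorm in float32 by default after this patch.
        keep_bn_fp32 = False
    if 'keep_batchnorm_fp32' in kwargs:
        # An explicit keyword argument always wins over the level default.
        keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
    return keep_bn_fp32

print(resolve_keep_bn_fp32("O2"))                            # True
print(resolve_keep_bn_fp32("O3"))                            # False (changed by this patch)
print(resolve_keep_bn_fp32("O3", keep_batchnorm_fp32=True))  # True (explicit override)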