!2326 fix the problem of BatchNorm config failure at Amp O3 level and some unexpected indent

Merge pull request !2326 from liangzelang/master
This commit is contained in:
mindspore-ci-bot 2020-06-20 09:20:15 +08:00 committed by Gitee
commit e4298d7d47
2 changed files with 4 additions and 2 deletions

View File

@ -127,7 +127,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32, - O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
using dynamic loss scale. using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'. - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
O2 is recommended on GPU, O3 is recommemded on Ascend.
O2 is recommended on GPU, O3 is recommended on Ascend.
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`. cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting. If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.

View File

@ -61,6 +61,7 @@ class Model:
- O0: Do not change. - O0: Do not change.
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale. - O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'. - O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
O2 is recommended on GPU, O3 is recommended on Ascend. O2 is recommended on GPU, O3 is recommended on Ascend.
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
@ -115,7 +116,7 @@ class Model:
self._build_predict_network() self._build_predict_network()
def _process_amp_args(self, kwargs): def _process_amp_args(self, kwargs):
if self._amp_level == "O0": if self._amp_level in ["O0", "O3"]:
self._keep_bn_fp32 = False self._keep_bn_fp32 = False
if 'keep_batchnorm_fp32' in kwargs: if 'keep_batchnorm_fp32' in kwargs:
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']