forked from mindspore-Ecosystem/mindspore
!2326 fix the problem of BatchNorm config failure at Amp O3 level and some unexpected indent
Merge pull request !2326 from liangzelang/master
This commit is contained in:
commit
e4298d7d47
|
@ -127,7 +127,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
|
||||||
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
|
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
|
||||||
using dynamic loss scale.
|
using dynamic loss scale.
|
||||||
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
|
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
|
||||||
O2 is recommended on GPU, O3 is recommemded on Ascend.
|
|
||||||
|
O2 is recommended on GPU, O3 is recommended on Ascend.
|
||||||
|
|
||||||
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
|
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
|
||||||
If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
|
If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
|
||||||
|
|
|
@ -61,6 +61,7 @@ class Model:
|
||||||
- O0: Do not change.
|
- O0: Do not change.
|
||||||
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
|
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
|
||||||
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
|
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
|
||||||
|
|
||||||
O2 is recommended on GPU, O3 is recommended on Ascend.
|
O2 is recommended on GPU, O3 is recommended on Ascend.
|
||||||
|
|
||||||
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
|
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
|
||||||
|
@ -115,7 +116,7 @@ class Model:
|
||||||
self._build_predict_network()
|
self._build_predict_network()
|
||||||
|
|
||||||
def _process_amp_args(self, kwargs):
|
def _process_amp_args(self, kwargs):
|
||||||
if self._amp_level == "O0":
|
if self._amp_level in ["O0", "O3"]:
|
||||||
self._keep_bn_fp32 = False
|
self._keep_bn_fp32 = False
|
||||||
if 'keep_batchnorm_fp32' in kwargs:
|
if 'keep_batchnorm_fp32' in kwargs:
|
||||||
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
|
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
|
||||||
|
|
Loading…
Reference in New Issue