!29473 bugfix: Model amp args differ from amp.build_train_network

Merge pull request !29473 from wangnan39/bugfix_Model_amp_config
i-robot 2022-01-24 11:27:27 +00:00 committed by Gitee
commit 597b0ce832
1 changed file with 11 additions and 17 deletions


@@ -179,7 +179,7 @@ class Model:
         self._optimizer = optimizer
         self._loss_scale_manager = None
         self._loss_scale_manager_set = False
-        self._keep_bn_fp32 = True
+        self._keep_bn_fp32 = None
         self._check_kwargs(kwargs)
         self._amp_level = amp_level
         self._boost_level = boost_level
@@ -214,8 +214,6 @@
             raise ValueError("For 'Model', the '**kwargs' argument should be empty when network is a GraphCell.")
 
     def _process_amp_args(self, kwargs):
-        if self._amp_level in ["O0", "O3"]:
-            self._keep_bn_fp32 = False
         if 'keep_batchnorm_fp32' in kwargs:
             self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
         if 'loss_scale_manager' in kwargs:
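
Together with the change of self._keep_bn_fp32 from True to None above, _process_amp_args now leaves the flag untouched unless the caller actually passes keep_batchnorm_fp32, instead of forcing it to False for the "O0" and "O3" levels. A minimal sketch of the resulting tri-state handling; the helper below is illustrative only, not a MindSpore API:

def resolve_keep_bn_fp32(kwargs):
    # None is the sentinel for "not specified by the caller"; only an explicit
    # keep_batchnorm_fp32 kwarg overrides it, mirroring the new _process_amp_args.
    keep_bn_fp32 = None
    if 'keep_batchnorm_fp32' in kwargs:
        keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
    return keep_bn_fp32

assert resolve_keep_bn_fp32({}) is None                            # defer to amp defaults
assert resolve_keep_bn_fp32({'keep_batchnorm_fp32': False}) is False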
@@ -270,21 +268,17 @@
             raise ValueError("The argument 'optimizer' can not be None when set 'loss_scale_manager'.")
         if self._optimizer:
+            amp_config = {}
             if self._loss_scale_manager_set:
-                network = amp.build_train_network(network,
-                                                  self._optimizer,
-                                                  self._loss_fn,
-                                                  level=self._amp_level,
-                                                  boost_level=self._boost_level,
-                                                  loss_scale_manager=self._loss_scale_manager,
-                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
-            else:
-                network = amp.build_train_network(network,
-                                                  self._optimizer,
-                                                  self._loss_fn,
-                                                  level=self._amp_level,
-                                                  boost_level=self._boost_level,
-                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
+                amp_config['loss_scale_manager'] = self._loss_scale_manager
+            if self._keep_bn_fp32 is not None:
+                amp_config['keep_batchnorm_fp32'] = self._keep_bn_fp32
+            network = amp.build_train_network(network,
+                                              self._optimizer,
+                                              self._loss_fn,
+                                              level=self._amp_level,
+                                              boost_level=self._boost_level,
+                                              **amp_config)
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
         # If need to check if loss_fn is not None, but optimizer is None
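
The net effect is that Model only forwards the amp kwargs the caller actually set, so its defaults can no longer drift from those of amp.build_train_network. A hedged usage sketch of the two paths that now agree; the tiny network, loss and optimizer are placeholders, and the mindspore.train.amp import path is an assumption, since the diff only references the module as amp:

from mindspore import Model, nn
from mindspore.train import amp

net = nn.Dense(10, 3)                                    # placeholder network
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
opt = nn.Momentum(net.trainable_params(), 0.01, 0.9)

# Path 1: Model applies the amp handling above when it later builds its train
# network. No keep_batchnorm_fp32 is passed, so after this fix the default is
# whatever amp.build_train_network itself uses.
model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O3")

# Path 2: calling amp.build_train_network directly with the same arguments.
train_net = amp.build_train_network(net, opt, loss, level="O3")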