!29473 bugfix: Model amp args differ from amp.build_train_network
Merge pull request !29473 from wangnan39/bugfix_Model_amp_config
This commit is contained in:
commit
597b0ce832
|
@ -179,7 +179,7 @@ class Model:
|
|||
self._optimizer = optimizer
|
||||
self._loss_scale_manager = None
|
||||
self._loss_scale_manager_set = False
|
||||
self._keep_bn_fp32 = True
|
||||
self._keep_bn_fp32 = None
|
||||
self._check_kwargs(kwargs)
|
||||
self._amp_level = amp_level
|
||||
self._boost_level = boost_level
|
||||
|
@ -214,8 +214,6 @@ class Model:
|
|||
raise ValueError("For 'Model', the '**kwargs' argument should be empty when network is a GraphCell.")
|
||||
|
||||
def _process_amp_args(self, kwargs):
|
||||
if self._amp_level in ["O0", "O3"]:
|
||||
self._keep_bn_fp32 = False
|
||||
if 'keep_batchnorm_fp32' in kwargs:
|
||||
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
|
||||
if 'loss_scale_manager' in kwargs:
|
||||
|
@ -270,21 +268,17 @@ class Model:
|
|||
raise ValueError("The argument 'optimizer' can not be None when set 'loss_scale_manager'.")
|
||||
|
||||
if self._optimizer:
|
||||
amp_config = {}
|
||||
if self._loss_scale_manager_set:
|
||||
network = amp.build_train_network(network,
|
||||
self._optimizer,
|
||||
self._loss_fn,
|
||||
level=self._amp_level,
|
||||
boost_level=self._boost_level,
|
||||
loss_scale_manager=self._loss_scale_manager,
|
||||
keep_batchnorm_fp32=self._keep_bn_fp32)
|
||||
else:
|
||||
network = amp.build_train_network(network,
|
||||
self._optimizer,
|
||||
self._loss_fn,
|
||||
level=self._amp_level,
|
||||
boost_level=self._boost_level,
|
||||
keep_batchnorm_fp32=self._keep_bn_fp32)
|
||||
amp_config['loss_scale_manager'] = self._loss_scale_manager
|
||||
if self._keep_bn_fp32 is not None:
|
||||
amp_config['keep_batchnorm_fp32'] = self._keep_bn_fp32
|
||||
network = amp.build_train_network(network,
|
||||
self._optimizer,
|
||||
self._loss_fn,
|
||||
level=self._amp_level,
|
||||
boost_level=self._boost_level,
|
||||
**amp_config)
|
||||
elif self._loss_fn:
|
||||
network = nn.WithLossCell(network, self._loss_fn)
|
||||
# If need to check if loss_fn is not None, but optimizer is None
|
||||
|
|
Loading…
Reference in New Issue