forked from mindspore-Ecosystem/mindspore
add keep_bn_fp32 parameter
This commit is contained in:
parent 0bf6717e9a
commit 672244e0ac
@@ -34,7 +34,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const
   auto prim = std::make_shared<Primitive>(kFusedMulAddNOpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim)};
   inputs.push_back(mul->input(kMulInputNum - lossscale_input_index));
-  inputs.push_back(addn->input(1));
+  inputs.push_back(addn->input(2));
   // scalar input should be 3rd input
   inputs.push_back(mul->input(lossscale_input_index));
   auto fusion_node = graph->NewCNode(inputs);
@@ -51,7 +51,7 @@ const BaseRef MulAddNFusion::DefinePattern() const {
   VarPtr Z = std::make_shared<Var>();

   VectorRef mul({prim::kPrimMul, X, Z});
-  VectorRef addn({prim::kPrimAddN, Y, mul});
+  VectorRef addn({prim::kPrimAddN, mul, Y});
   return addn;
 }

@@ -65,7 +65,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
   if (addn == nullptr || addn->inputs().size() != kAddNInputNum) {
     return nullptr;
   }
-  auto mul_anf = addn->input(2);
+  auto mul_anf = addn->input(1);
   if (mul_anf == nullptr) {
     return nullptr;
   }
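For orientation, the three hunks above change which AddN operand the fusion pass treats as the Mul result: the pattern now matches the Mul output as the first AddN input, so the remaining operand indices shift accordingly. Below is a minimal NumPy sketch of the numerics the fused node stands for; the function and variable names are illustrative only, not MindSpore APIs.

import numpy as np

def fused_mul_addn_reference(x, y, scalar):
    """Reference numerics for the pattern being fused:
    AddN(Mul(x, scalar), y) -> FusedMulAddN(x, y, scalar), with the scalar as 3rd input."""
    return x * scalar + y

x = np.ones((2, 2), dtype=np.float32)
y = np.full((2, 2), 2.0, dtype=np.float32)
print(fused_mul_addn_reference(x, y, np.float32(0.5)))  # every element is 2.5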
@@ -177,7 +177,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        return op_add((gradient, weight * weight_decay))
+        return op_add((weight * weight_decay, gradient))
     return gradient


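The swap in `_tensor_apply_decay` puts the decayed-weight term (a Mul output) first in the tuple handed to `op_add`, presumably so that the AddN node built by weight decay matches the reordered fusion pattern above. A rough standalone sketch of the numerics follows; `op_add` here is a plain tuple sum standing in for the MindSpore composite.

def op_add(operands):
    # Stand-in for an AddN-style composite: sums every element of the tuple.
    total = operands[0]
    for term in operands[1:]:
        total = total + term
    return total

def tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Mirrors _tensor_apply_decay: decayed weight first, gradient second."""
    if if_apply:
        return op_add((weight * weight_decay, gradient))
    return gradient

print(tensor_apply_decay(0.1, True, 10.0, 0.5))  # 1.0 + 0.5 = 1.5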
@@ -62,6 +62,7 @@ class Model:
         loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
             scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a keyword argument.
             e.g. Use `loss_scale_manager=None` to set the value.
+        keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True.

     Examples:
         >>> class Net(nn.Cell):
@@ -96,7 +97,10 @@ class Model:
         self._optimizer = optimizer
         self._loss_scale_manager = None
         self._loss_scale_manager_set = False
+        self._keep_bn_fp32 = True
         self._check_kwargs(kwargs)
+        if 'keep_batchnorm_fp32' in kwargs:
+            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
         if 'loss_scale_manager' in kwargs:
             self._loss_scale_manager = kwargs['loss_scale_manager']
             self._loss_scale_manager_set = True
@@ -112,7 +116,7 @@ class Model:

     def _check_kwargs(self, kwargs):
         for arg in kwargs:
-            if arg not in ['loss_scale_manager']:
+            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
                 raise ValueError(f"Unsupport arg '{arg}'")

     def _build_train_network(self):
@@ -124,12 +128,14 @@ class Model:
                                                   self._optimizer,
                                                   self._loss_fn,
                                                   level=self._amp_level,
-                                                  loss_scale_manager=self._loss_scale_manager)
+                                                  loss_scale_manager=self._loss_scale_manager,
+                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
             else:
                 network = amp.build_train_network(network,
                                                   self._optimizer,
                                                   self._loss_fn,
-                                                  level=self._amp_level)
+                                                  level=self._amp_level,
+                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
         # If need to check if loss_fn is not None, but optimizer is None
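Taken together, the Model changes let callers pass the new `keep_batchnorm_fp32` keyword through to `amp.build_train_network`. A minimal usage sketch follows; the network, loss and optimizer below are placeholders chosen for the example and are not part of this diff.

import mindspore.nn as nn
from mindspore.train.model import Model

# Placeholder pieces; any Cell-based network, loss and optimizer work the same way.
net = nn.Dense(16, 10)
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

model = Model(net, loss_fn=loss_fn, optimizer=optimizer,
              loss_scale_manager=None,    # keyword argument; overrides the amp level setting
              keep_batchnorm_fp32=True)   # keyword argument added by this commit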
@@ -42,7 +42,7 @@ def test_mul_addn_fusion(tag):
     @fns
     def before(a, b):
         res = mul(scalar, a)
-        res = addn((b, res))
+        res = addn((res, b))
         return res

     @fns