diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc
index f2e26a61771..9191368e2f1 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/less_batch_normalization.cc
@@ -20,6 +20,7 @@ namespace mindspore {
 namespace opt {
 namespace irpass {
 namespace {
+enum RemoveNodeType { kOtherNode = 0, kOptimizerNode };
 const char kLessBatchNormalizationPassName[] = "less_bn";
 constexpr auto kValidResidualStructureIndex = 1;
 constexpr auto kBNParametersStartIndex = 2;
@@ -63,6 +64,11 @@ const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
   {kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
 static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
   ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
+const std::set<PrimitivePtr> kNeedRemoveNodeSet{
+  prim::kPrimLoad,      prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
+  prim::kPrimApplyFtrl, prim::kPrimSGD,        prim::kPrimApplyRMSProp,  prim::kPrimAdam};
+static std::unordered_map<RemoveNodeType, std::set<size_t>> kRemoveIndex{
+  {RemoveNodeType::kOtherNode, {2}}, {RemoveNodeType::kOptimizerNode, {3, 5, 6}}};
 
 bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
   if (a == nullptr) {
@@ -73,13 +79,56 @@ bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
   });
 }
 
+bool IsNotRealUseNode(const AnfNodePtr &node) {
+  for (const auto &prim : kNeedRemoveNodeSet) {
+    if (IsPrimitiveCNode(node, prim)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+CNodePtr ConvertRemoveNodeToVirtualNode(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<AnfNodePtr> args;
+  size_t index = 0;
+  const auto &inputs = cnode->inputs();
+  auto remove_index = kRemoveIndex[RemoveNodeType::kOptimizerNode];
+  if (IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimRefToEmbed)) {
+    remove_index = kRemoveIndex[RemoveNodeType::kOtherNode];
+  }
+
+  (void)std::copy_if(
+    inputs.begin(), inputs.end(), std::back_inserter(args),
+    [&remove_index, &index](const AnfNodePtr &) { return remove_index.find(index++) != remove_index.end(); });
+
+  (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));
+  const auto &fg = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(fg);
+  auto new_make_tuple = fg->NewCNode(args);
+  return new_make_tuple;
+}
+
 bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &parameter) {
   auto param_output = manager->node_users().find(parameter);
   if (param_output == manager->node_users().end()) {
     return true;
   }
 
-  return false;
+  bool need_remove = true;
+  auto output_info_list = param_output->second;
+  for (const auto &output_info : output_info_list) {
+    const auto &node = output_info.first;
+    if (IsNotRealUseNode(node)) {
+      const auto &cnode = node->cast<CNodePtr>();
+      const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
+      manager->Replace(cnode, new_cnode);
+      continue;
+    }
+    need_remove = false;
+  }
+
+  return need_remove;
 }
 
 void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
@@ -98,12 +147,20 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
                  [&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
 
   auto root_parameters = root_graph->parameters();
+  size_t origin_param_count = root_parameters.size();
   root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
                                        [&real_remove_parameter_list](const AnfNodePtr &node) {
                                          return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
                                        }),
                         root_parameters.end());
-
+  size_t remove_param_count = origin_param_count - root_parameters.size();
+  size_t hyper_param_count = root_graph->hyper_param_count();
+  if (remove_param_count > hyper_param_count) {
+    MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
+    return;
+  }
+  hyper_param_count = hyper_param_count - remove_param_count;
+  root_graph->set_hyper_param_count(hyper_param_count);
   manager->SetParameters(root_graph, root_parameters);
 }
 }  // namespace
diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc
index 3bf1a463beb..f1b667968be 100644
--- a/mindspore/core/ir/func_graph_extends.cc
+++ b/mindspore/core/ir/func_graph_extends.cc
@@ -231,6 +231,9 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
   std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
   std::vector<size_t> pos_arg_indexes;
   size_t arguments_count = args_spec_list.size();
+  if (hyper_param_count_ > arguments_count) {
+    MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
+  }
   for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
     if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
diff --git a/mindspore/nn/acc/less_batch_normalization.py b/mindspore/nn/acc/less_batch_normalization.py
index 95db064e540..ee6ddecafbf 100644
--- a/mindspore/nn/acc/less_batch_normalization.py
+++ b/mindspore/nn/acc/less_batch_normalization.py
@@ -13,8 +13,74 @@
 # limitations under the License.
 # ============================================================================
 """less batch normalization"""
+import numpy as np
+from mindspore import nn
+from mindspore.ops import operations as P
+from mindspore import Tensor, Parameter
+from mindspore import dtype as mstype
+from mindspore.common.initializer import initializer
 from ..cell import Cell
+
+__all__ = ["LessBN"]
+
+
+class CommonHeadLastFN(Cell):
+    r"""
+    The last full normalization layer.
+
+    This layer implements the operation as:
+
+    .. math::
+        \text{inputs} = \text{norm}(\text{inputs})
+        \text{kernel} = \text{norm}(\text{kernel})
+        \text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias}),
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is the same as the input x. The values of str refer to the function `initializer`. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            the same as the input x. The values of str refer to the function `initializer`. Default: 'zeros'.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
+        >>> net = CommonHeadLastFN(3, 4)
+        >>> output = net(input)
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 weight_init='normal',
+                 bias_init='zeros',
+                 has_bias=True):
+        super(CommonHeadLastFN, self).__init__()
+        weight_shape = [out_channels, in_channels]
+        self.weight = Parameter(initializer(weight_init, weight_shape), requires_grad=True, name='weight')
+        self.x_norm = P.L2Normalize(axis=1)
+        self.w_norm = P.L2Normalize(axis=1)
+        self.fc = P.MatMul(transpose_a=False, transpose_b=True)
+        self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32), requires_grad=True, name='multiplier')
+        self.has_bias = has_bias
+        if self.has_bias:
+            bias_shape = [out_channels]
+            self.bias_add = P.BiasAdd()
+            self.bias = Parameter(initializer(bias_init, bias_shape), requires_grad=True, name='bias')
+
+    def construct(self, x):
+        x = self.x_norm(x)
+        w = self.w_norm(self.weight)
+        x = self.fc(x, w)
+        if self.has_bias:
+            x = self.bias_add(x, self.bias)
+        x = self.multiplier * x
+        return x
+
 
 class LessBN(Cell):
     """
     Reduce the number of BN automatically to improve the network performance
@@ -31,6 +97,44 @@ class LessBN(Cell):
         super(LessBN, self).__init__()
         self.network = network
         self.network.set_acc("less_bn")
+        self.network.update_cell_prefix()
+        self._convert_to_less_bn_net(self.network)
+        self.network.add_flags(defer_inline=True)
+
+    def _convert_dense(self, subcell):
+        """
+        Convert a Dense cell to a full-normalization (FN) cell.
+        """
+        prefix = subcell.param_prefix
+        new_subcell = CommonHeadLastFN(subcell.in_channels,
+                                       subcell.out_channels,
+                                       subcell.weight,
+                                       subcell.bias,
+                                       subcell.has_bias)
+        new_subcell.update_parameters_name(prefix + '.')
+
+        return new_subcell
+
+    def _convert_to_less_bn_net(self, net):
+        """
+        Convert the network to a less_bn network.
+        """
+        cells = net.name_cells()
+        dense_name = []
+        dense_list = []
+        for name in cells:
+            subcell = cells[name]
+            if subcell == net:
+                continue
+            elif isinstance(subcell, (nn.Dense)):
+                dense_name.append(name)
+                dense_list.append(subcell)
+            else:
+                self._convert_to_less_bn_net(subcell)
+
+        if dense_list:
+            new_subcell = self._convert_dense(dense_list[-1])
+            net.insert_child_to_cell(dense_name[-1], new_subcell)
 
     def construct(self, *inputs):
         return self.network(*inputs)
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index a92c4b6cf49..97e0466b2a8 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -1048,7 +1048,7 @@ class Cell(Cell_):
         Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
 
         Args:
-            acc_type (:str:`less_bn`): accelerate algorithm.
+            acc_type (str): acceleration algorithm.
 
         Raises:
             ValueError: If acc_type is not in the algorithm library.
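A minimal usage sketch of the new Python pieces above, illustrative only and not part of the patch: the import path assumes LessBN is re-exported from mindspore.nn.acc (the new module only defines __all__ = ["LessBN"]), and ToyNet is a made-up network whose trailing nn.Dense triggers the CommonHeadLastFN replacement.

    # Hypothetical usage sketch; names and shapes here are assumptions, not part of this patch.
    import numpy as np
    import mindspore as ms
    from mindspore import nn, Tensor
    from mindspore.nn.acc import LessBN  # assumed re-export of the new module

    class ToyNet(nn.Cell):
        def __init__(self):
            super(ToyNet, self).__init__()
            self.conv = nn.Conv2d(3, 16, 3)
            self.bn = nn.BatchNorm2d(16)
            self.relu = nn.ReLU()
            self.flatten = nn.Flatten()
            self.fc = nn.Dense(16 * 32 * 32, 10)  # the last Dense is swapped for CommonHeadLastFN

        def construct(self, x):
            x = self.relu(self.bn(self.conv(x)))
            return self.fc(self.flatten(x))

    net = LessBN(ToyNet())  # sets the "less_bn" acc flag and converts the head
    out = net(Tensor(np.ones((1, 3, 32, 32)), ms.float32))
    print(out.shape)  # (1, 10)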
diff --git a/mindspore/nn/optim/ada_grad.py b/mindspore/nn/optim/ada_grad.py
index fa27efbcf61..a029502ac29 100644
--- a/mindspore/nn/optim/ada_grad.py
+++ b/mindspore/nn/optim/ada_grad.py
@@ -157,8 +157,8 @@ class Adagrad(Optimizer):
         params = self.parameters
         accum = self.accum
         grads = self.decay_weight(grads)
-        grads = self.scale_grad(grads)
         grads = self.gradients_centralization(grads)
+        grads = self.scale_grad(grads)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum,
diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index de61a152cd6..482b6db4cc5 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -342,9 +342,9 @@ class Adam(Optimizer):
         moment1 = self.moment1
         moment2 = self.moment2
         gradients = self.decay_weight(gradients)
+        gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         gradients = self._grad_sparse_indices_deduplicate(gradients)
-        gradients = self.gradients_centralization(gradients)
         lr = self.get_lr()
 
         beta1_power = self.beta1_power * self.beta1
diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py
index 3630d98cb5f..93b2defe2da 100644
--- a/mindspore/nn/optim/ftrl.py
+++ b/mindspore/nn/optim/ftrl.py
@@ -222,9 +222,9 @@ class FTRL(Optimizer):
         moments = self.moments
         linear = self.linear
         grads = self.decay_weight(grads)
+        grads = self.gradients_centralization(grads)
         grads = self.scale_grad(grads)
         grads = self._grad_sparse_indices_deduplicate(grads)
-        grads = self.gradients_centralization(grads)
         lr = self.get_lr()
         success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
diff --git a/mindspore/nn/optim/lazyadam.py b/mindspore/nn/optim/lazyadam.py
index c21c095612f..8eca2300582 100644
--- a/mindspore/nn/optim/lazyadam.py
+++ b/mindspore/nn/optim/lazyadam.py
@@ -258,9 +258,9 @@ class LazyAdam(Optimizer):
     def construct(self, gradients):
         gradients = self.decay_weight(gradients)
+        gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         gradients = self._grad_sparse_indices_deduplicate(gradients)
-        gradients = self.gradients_centralization(gradients)
         lr = self.get_lr()
 
         self.beta1_power = self.beta1_power * self.beta1
diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py
index 74f4af81c0b..4ef1019ecd8 100755
--- a/mindspore/nn/optim/momentum.py
+++ b/mindspore/nn/optim/momentum.py
@@ -163,8 +163,8 @@ class Momentum(Optimizer):
         params = self.params
         moments = self.moments
         gradients = self.decay_weight(gradients)
-        gradients = self.scale_grad(gradients)
         gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index a3a1a53680c..81218a92927 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -443,10 +443,6 @@ class Optimizer(Cell):
                 self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
                 for param in group_param['params']:
                     validator.check_value_type("parameter", param, [Parameter], self.cls_name)
-                    if "conv" not in param.name and self.grad_centralization is True:
-                        raise ValueError("Grad centralization can be perform only on the conv layer. If the parameter"
-                                         "is not a convolution layer, this parameter cannot be set to True.")
-
                 grad_centralization_ = self.grad_centralization
             else:
                 grad_centralization_ = grad_centralization
@@ -634,9 +630,16 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
     """Get grad with grad_centralization."""
     if if_apply:
         indices = gradient.indices
-        values = op_gc(gradient.values, -1)
         shape = gradient.dense_shape
-        return RowTensor(indices, values, shape)
+        grad_shape = F.shape(gradient)
+        axis = []
+        for i in range(1, len(grad_shape)):
+            axis.append(i)
+        if len(axis) >= 1:
+            if grad_shape[1] % 16 != 0:
+                return gradient
+        values = op_gc(gradient.values, axis)
+        return RowTensor(indices, values, shape)
     return gradient
@@ -644,7 +647,14 @@ def _tensor_apply_grad_centralization(if_apply, gradient):
     """Get grad with grad_centralization."""
     if if_apply:
-        return op_gc(gradient, -1)
+        axis = []
+        grad_shape = F.shape(gradient)
+        for i in range(1, len(grad_shape)):
+            axis.append(i)
+        if len(axis) >= 1:
+            if grad_shape[1] % 16 != 0:
+                return gradient
+        return op_gc(gradient, axis)
     return gradient
diff --git a/mindspore/nn/optim/proximal_ada_grad.py b/mindspore/nn/optim/proximal_ada_grad.py
index 1a181111026..90ace9cc7b5 100644
--- a/mindspore/nn/optim/proximal_ada_grad.py
+++ b/mindspore/nn/optim/proximal_ada_grad.py
@@ -169,9 +169,9 @@ class ProximalAdagrad(Optimizer):
         params = self.parameters
         accum = self.accum
         grads = self.decay_weight(grads)
+        grads = self.gradients_centralization(grads)
         grads = self.scale_grad(grads)
         grads = self._grad_sparse_indices_deduplicate(grads)
-        grads = self.gradients_centralization(grads)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py
index f50beb14b90..dfd9314aa5e 100644
--- a/mindspore/nn/optim/rmsprop.py
+++ b/mindspore/nn/optim/rmsprop.py
@@ -204,8 +204,8 @@ class RMSProp(Optimizer):
     def construct(self, gradients):
         params = self.parameters
         gradients = self.decay_weight(gradients)
-        gradients = self.scale_grad(gradients)
         gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.centered:
             if self.is_group_lr:
diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py
index 99131f31ef6..9aff96c144c 100755
--- a/mindspore/nn/optim/sgd.py
+++ b/mindspore/nn/optim/sgd.py
@@ -176,8 +176,8 @@ class SGD(Optimizer):
         params = self.parameters
         accum = self.accum
         stat = self.stat
-        gradients = self.scale_grad(gradients)
         gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
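All of the optimizer hunks above move gradients_centralization ahead of scale_grad, and the helpers in optimizer.py now centralize over every axis except axis 0 instead of only the last axis, skipping tensors whose second dimension is not a multiple of 16. Below is a plain NumPy sketch of that behaviour, illustrative only; the real kernel is the Centralization TBE op registered in the next diff.

    # Illustrative NumPy model of the centralization helpers; not part of this patch.
    import numpy as np

    def centralize(grad):
        """Subtract the mean taken over all axes except axis 0."""
        axis = tuple(range(1, grad.ndim))     # every axis except 0, as in the patched helpers
        if not axis:                          # 1-D gradients: nothing to reduce over in this sketch
            return grad
        if grad.shape[1] % 16 != 0:           # same shape guard as in optimizer.py above
            return grad
        return grad - grad.mean(axis=axis, keepdims=True)

    g = np.random.randn(32, 64, 3, 3).astype(np.float32)  # e.g. a conv weight gradient
    print(np.allclose(centralize(g).mean(axis=(1, 2, 3)), 0.0, atol=1e-6))  # True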
diff --git a/mindspore/ops/_op_impl/tbe/centralization.py b/mindspore/ops/_op_impl/tbe/centralization.py
index a06a4bab0b9..e1ca3245c82 100644
--- a/mindspore/ops/_op_impl/tbe/centralization.py
+++ b/mindspore/ops/_op_impl/tbe/centralization.py
@@ -26,9 +26,10 @@ centralization_op_info = TBERegOp("Centralization") \
     .attr("axis", "required", "listInt", "all") \
     .input(0, "x", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
-    .op_pattern("reduce") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
     .get_op_info()
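Before this change, requesting 'grad_centralization': True for a parameter group containing non-convolution parameters raised a ValueError in Optimizer; with that check removed, the group-parameter pattern below is accepted and the shape guard above decides per tensor whether centralization actually runs. A hedged sketch, where net is a placeholder (e.g. the ToyNet sketched earlier):

    # Illustrative group-parameter setup; "net" is a placeholder, not part of this patch.
    from mindspore import nn

    group_params = [{'params': net.trainable_params(), 'grad_centralization': True}]
    opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)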