!14809 [Less BN]New add FN, GC optimizer.

From: @linqingke
Reviewed-by: @guoqi1024,@xu-yfei
Signed-off-by: @xu-yfei
This commit is contained in:
mindspore-ci-bot 2021-04-15 11:17:51 +08:00 committed by Gitee
commit c43deb5469
14 changed files with 194 additions and 19 deletions

View File

@ -20,6 +20,7 @@ namespace mindspore {
namespace opt {
namespace irpass {
namespace {
enum RemoveNodeType { kOtherNode = 0, kOptimizerNode };
const char kLessBatchNormalizationPassName[] = "less_bn";
constexpr auto kValidResidualStructureIndex = 1;
constexpr auto kBNParametersStartIndex = 2;
@ -63,6 +64,11 @@ const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
{kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
const std::set<PrimitivePtr> kNeedRemoveNodeSet{
prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam};
static std::unordered_map<RemoveNodeType, std::unordered_set<size_t>> kRemoveIndex{
{RemoveNodeType::kOtherNode, {2}}, {RemoveNodeType::kOptimizerNode, {3, 5, 6}}};
bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
if (a == nullptr) {
@ -73,13 +79,56 @@ bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_
});
}
bool IsNotRealUseNode(const AnfNodePtr &node) {
for (const auto &prim : kNeedRemoveNodeSet) {
if (IsPrimitiveCNode(node, prim)) {
return true;
}
}
return false;
}
CNodePtr ConvertRemoveNodeToVirtualNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> args;
size_t index = 0;
const auto &inputs = cnode->inputs();
auto remove_index = kRemoveIndex[RemoveNodeType::kOptimizerNode];
if (IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimRefToEmbed)) {
remove_index = kRemoveIndex[RemoveNodeType::kOtherNode];
}
(void)std::copy_if(
inputs.begin(), inputs.end(), std::back_inserter(args),
[&remove_index, &index](const AnfNodePtr &) { return remove_index.find(index++) != remove_index.end(); });
(void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));
const auto &fg = cnode->func_graph();
MS_EXCEPTION_IF_NULL(fg);
auto new_make_tuple = fg->NewCNode(args);
return new_make_tuple;
}
bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &parameter) {
auto param_output = manager->node_users().find(parameter);
if (param_output == manager->node_users().end()) {
return true;
}
return false;
bool need_remove = true;
auto output_info_list = param_output->second;
for (const auto &output_info : output_info_list) {
const auto &node = output_info.first;
if (IsNotRealUseNode(node)) {
const auto &cnode = node->cast<CNodePtr>();
const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
manager->Replace(cnode, new_cnode);
continue;
}
need_remove = false;
}
return need_remove;
}
void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
@ -98,12 +147,20 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
[&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
auto root_parameters = root_graph->parameters();
size_t origin_param_count = root_parameters.size();
root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
[&real_remove_parameter_list](const AnfNodePtr &node) {
return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
}),
root_parameters.end());
size_t remove_param_count = origin_param_count - root_parameters.size();
size_t hyper_param_count = root_graph->hyper_param_count();
if (remove_param_count > hyper_param_count) {
MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
return;
}
hyper_param_count = hyper_param_count - remove_param_count;
root_graph->set_hyper_param_count(hyper_param_count);
manager->SetParameters(root_graph, root_parameters);
}
} // namespace

View File

@ -231,6 +231,9 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
std::vector<size_t> pos_arg_indexes;
size_t arguments_count = args_spec_list.size();
if (hyper_param_count_ > arguments_count) {
MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
}
for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
MS_EXCEPTION_IF_NULL(args_spec_list[i]);
if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {

View File

@ -13,8 +13,74 @@
# limitations under the License.
# ============================================================================
"""less batch normalization"""
import numpy as np
from mindspore import nn
from mindspore.ops import operations as P
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.common.initializer import initializer
from ..cell import Cell
__all__ = ["LessBN"]
class CommonHeadLastFN(Cell):
r"""
The last full normalization layer.
This layer implements the operation as:
.. math::
\text{inputs} = \text{norm}(\text{inputs})
\text{kernel} = \text{norm}(\text{kernel})
\text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias}),
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
Supported Platforms:
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
>>> net = CommonHeadLastFN(3, 4)
>>> output = net(input)
"""
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True):
super(CommonHeadLastFN, self).__init__()
weight_shape = [out_channels, in_channels]
self.weight = Parameter(initializer(weight_init, weight_shape), requires_grad=True, name='weight')
self.x_norm = P.L2Normalize(axis=1)
self.w_norm = P.L2Normalize(axis=1)
self.fc = P.MatMul(transpose_a=False, transpose_b=True)
self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32), requires_grad=True, name='multiplier')
self.has_bias = has_bias
if self.has_bias:
bias_shape = [out_channels]
self.bias_add = P.BiasAdd()
self.bias = Parameter(initializer(bias_init, bias_shape), requires_grad=True, name='bias')
def construct(self, x):
x = self.x_norm(x)
w = self.w_norm(self.weight)
x = self.fc(x, w)
if self.has_bias:
x = self.bias_add(x, self.bias)
x = self.multiplier * x
return x
class LessBN(Cell):
"""
Reduce the number of BN automatically to improve the network performance
@ -31,6 +97,44 @@ class LessBN(Cell):
super(LessBN, self).__init__()
self.network = network
self.network.set_acc("less_bn")
self.network.update_cell_prefix()
self._convert_to_less_bn_net(self.network)
self.network.add_flags(defer_inline=True)
def _convert_dense(self, subcell):
"""
convert dense cell to FN cell
"""
prefix = subcell.param_prefix
new_subcell = CommonHeadLastFN(subcell.in_channels,
subcell.out_channels,
subcell.weight,
subcell.bias,
subcell.has_bias)
new_subcell.update_parameters_name(prefix + '.')
return new_subcell
def _convert_to_less_bn_net(self, net):
"""
convert network to less_bn network
"""
cells = net.name_cells()
dense_name = []
dense_list = []
for name in cells:
subcell = cells[name]
if subcell == net:
continue
elif isinstance(subcell, (nn.Dense)):
dense_name.append(name)
dense_list.append(subcell)
else:
self._convert_to_less_bn_net(subcell)
if dense_list:
new_subcell = self._convert_dense(dense_list[-1])
net.insert_child_to_cell(dense_name[-1], new_subcell)
def construct(self, *inputs):
return self.network(*inputs)

View File

@ -1048,7 +1048,7 @@ class Cell(Cell_):
Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
Args:
acc_type (:str:`less_bn`): accelerate algorithm.
acc_type (str): accelerate algorithm.
Raises:
ValueError: If acc_type is not in the algorithm library.

View File

@ -157,8 +157,8 @@ class Adagrad(Optimizer):
params = self.parameters
accum = self.accum
grads = self.decay_weight(grads)
grads = self.scale_grad(grads)
grads = self.gradients_centralization(grads)
grads = self.scale_grad(grads)
lr = self.get_lr()
if self.is_group_lr:
success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum,

View File

@ -342,9 +342,9 @@ class Adam(Optimizer):
moment1 = self.moment1
moment2 = self.moment2
gradients = self.decay_weight(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
gradients = self._grad_sparse_indices_deduplicate(gradients)
gradients = self.gradients_centralization(gradients)
lr = self.get_lr()
beta1_power = self.beta1_power * self.beta1

View File

@ -222,9 +222,9 @@ class FTRL(Optimizer):
moments = self.moments
linear = self.linear
grads = self.decay_weight(grads)
grads = self.gradients_centralization(grads)
grads = self.scale_grad(grads)
grads = self._grad_sparse_indices_deduplicate(grads)
grads = self.gradients_centralization(grads)
lr = self.get_lr()
success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,

View File

@ -258,9 +258,9 @@ class LazyAdam(Optimizer):
def construct(self, gradients):
gradients = self.decay_weight(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
gradients = self._grad_sparse_indices_deduplicate(gradients)
gradients = self.gradients_centralization(gradients)
lr = self.get_lr()
self.beta1_power = self.beta1_power * self.beta1

View File

@ -163,8 +163,8 @@ class Momentum(Optimizer):
params = self.params
moments = self.moments
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
if self.is_group_lr:
success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,

View File

@ -443,10 +443,6 @@ class Optimizer(Cell):
self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
for param in group_param['params']:
validator.check_value_type("parameter", param, [Parameter], self.cls_name)
if "conv" not in param.name and self.grad_centralization is True:
raise ValueError("Grad centralization can be perform only on the conv layer. If the parameter"
"is not a convolution layer, this parameter cannot be set to True.")
grad_centralization_ = self.grad_centralization
else:
grad_centralization_ = grad_centralization
@ -634,9 +630,16 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
"""Get grad with grad_centralization."""
if if_apply:
indices = gradient.indices
values = op_gc(gradient.values, -1)
shape = gradient.dense_shape
return RowTensor(indices, values, shape)
grad_shape = F.shape(gradient)
axis = []
for i in range(1, len(grad_shape)):
axis.append(i)
if len(axis) >= 1:
if grad_shape[1] % 16 != 0:
return gradient
values = op_gc(gradient.values, axis)
return RowTensor(indices, values, shape)
return gradient
@ -644,7 +647,14 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
def _tensor_apply_grad_centralization(if_apply, gradient):
"""Get grad with grad_centralization."""
if if_apply:
return op_gc(gradient, -1)
axis = []
grad_shape = F.shape(gradient)
for i in range(1, len(grad_shape)):
axis.append(i)
if len(axis) >= 1:
if grad_shape[1] % 16 != 0:
return gradient
return op_gc(gradient, axis)
return gradient

View File

@ -169,9 +169,9 @@ class ProximalAdagrad(Optimizer):
params = self.parameters
accum = self.accum
grads = self.decay_weight(grads)
grads = self.gradients_centralization(grads)
grads = self.scale_grad(grads)
grads = self._grad_sparse_indices_deduplicate(grads)
grads = self.gradients_centralization(grads)
lr = self.get_lr()
if self.is_group_lr:
success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,

View File

@ -204,8 +204,8 @@ class RMSProp(Optimizer):
def construct(self, gradients):
params = self.parameters
gradients = self.decay_weight(gradients)
gradients = self.scale_grad(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
if self.centered:
if self.is_group_lr:

View File

@ -176,8 +176,8 @@ class SGD(Optimizer):
params = self.parameters
accum = self.accum
stat = self.stat
gradients = self.scale_grad(gradients)
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
lr = self.get_lr()
if self.is_group_lr:
success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)

View File

@ -26,9 +26,10 @@ centralization_op_info = TBERegOp("Centralization") \
.attr("axis", "required", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.op_pattern("reduce") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
.get_op_info()