forked from mindspore-Ecosystem/mindspore
!14809 [Less BN] Add FN and GC optimizer.
From: @linqingke Reviewed-by: @guoqi1024, @xu-yfei Signed-off-by: @xu-yfei
commit c43deb5469
@@ -20,6 +20,7 @@ namespace mindspore {
namespace opt {
namespace irpass {
namespace {
enum RemoveNodeType { kOtherNode = 0, kOptimizerNode };
const char kLessBatchNormalizationPassName[] = "less_bn";
constexpr auto kValidResidualStructureIndex = 1;
constexpr auto kBNParametersStartIndex = 2;
@@ -63,6 +64,11 @@ const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
  {kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
  ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
const std::set<PrimitivePtr> kNeedRemoveNodeSet{
  prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
  prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam};
static std::unordered_map<RemoveNodeType, std::unordered_set<size_t>> kRemoveIndex{
  {RemoveNodeType::kOtherNode, {2}}, {RemoveNodeType::kOptimizerNode, {3, 5, 6}}};

bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
  if (a == nullptr) {
@@ -73,13 +79,56 @@ bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
  });
}

bool IsNotRealUseNode(const AnfNodePtr &node) {
  for (const auto &prim : kNeedRemoveNodeSet) {
    if (IsPrimitiveCNode(node, prim)) {
      return true;
    }
  }
  return false;
}

CNodePtr ConvertRemoveNodeToVirtualNode(const CNodePtr &cnode) {
  MS_EXCEPTION_IF_NULL(cnode);
  std::vector<AnfNodePtr> args;
  size_t index = 0;
  const auto &inputs = cnode->inputs();
  auto remove_index = kRemoveIndex[RemoveNodeType::kOptimizerNode];
  if (IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimRefToEmbed)) {
    remove_index = kRemoveIndex[RemoveNodeType::kOtherNode];
  }

  (void)std::copy_if(
    inputs.begin(), inputs.end(), std::back_inserter(args),
    [&remove_index, &index](const AnfNodePtr &) { return remove_index.find(index++) != remove_index.end(); });

  (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));
  const auto &fg = cnode->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  auto new_make_tuple = fg->NewCNode(args);
  return new_make_tuple;
}

bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &parameter) {
  auto param_output = manager->node_users().find(parameter);
  if (param_output == manager->node_users().end()) {
    return true;
  }

  return false;
  bool need_remove = true;
  auto output_info_list = param_output->second;
  for (const auto &output_info : output_info_list) {
    const auto &node = output_info.first;
    if (IsNotRealUseNode(node)) {
      const auto &cnode = node->cast<CNodePtr>();
      const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
      manager->Replace(cnode, new_cnode);
      continue;
    }
    need_remove = false;
  }

  return need_remove;
}

void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
@@ -98,12 +147,20 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
                     [&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });

  auto root_parameters = root_graph->parameters();
  size_t origin_param_count = root_parameters.size();
  root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
                                       [&real_remove_parameter_list](const AnfNodePtr &node) {
                                         return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
                                       }),
                        root_parameters.end());

  size_t remove_param_count = origin_param_count - root_parameters.size();
  size_t hyper_param_count = root_graph->hyper_param_count();
  if (remove_param_count > hyper_param_count) {
    MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
    return;
  }
  hyper_param_count = hyper_param_count - remove_param_count;
  root_graph->set_hyper_param_count(hyper_param_count);
  manager->SetParameters(root_graph, root_parameters);
}
}  // namespace
@@ -231,6 +231,9 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
  std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
  std::vector<size_t> pos_arg_indexes;
  size_t arguments_count = args_spec_list.size();
  if (hyper_param_count_ > arguments_count) {
    MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
  }
  for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
    MS_EXCEPTION_IF_NULL(args_spec_list[i]);
    if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
@@ -13,8 +13,74 @@
# limitations under the License.
# ============================================================================
"""less batch normalization"""
import numpy as np
from mindspore import nn
from mindspore.ops import operations as P
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.common.initializer import initializer
from ..cell import Cell


__all__ = ["LessBN"]


class CommonHeadLastFN(Cell):
    r"""
    The last full normalization layer.

    This layer implements the operation as:

    .. math::
        \text{inputs} = \text{norm}(\text{inputs})
        \text{kernel} = \text{norm}(\text{kernel})
        \text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias}),

    Args:
        in_channels (int): The number of channels in the input space.
        out_channels (int): The number of channels in the output space.
        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
        >>> net = CommonHeadLastFN(3, 4)
        >>> output = net(input)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 weight_init='normal',
                 bias_init='zeros',
                 has_bias=True):
        super(CommonHeadLastFN, self).__init__()
        weight_shape = [out_channels, in_channels]
        self.weight = Parameter(initializer(weight_init, weight_shape), requires_grad=True, name='weight')
        self.x_norm = P.L2Normalize(axis=1)
        self.w_norm = P.L2Normalize(axis=1)
        self.fc = P.MatMul(transpose_a=False, transpose_b=True)
        self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32), requires_grad=True, name='multiplier')
        self.has_bias = has_bias
        if self.has_bias:
            bias_shape = [out_channels]
            self.bias_add = P.BiasAdd()
            self.bias = Parameter(initializer(bias_init, bias_shape), requires_grad=True, name='bias')

    def construct(self, x):
        x = self.x_norm(x)
        w = self.w_norm(self.weight)
        x = self.fc(x, w)
        if self.has_bias:
            x = self.bias_add(x, self.bias)
        x = self.multiplier * x
        return x


class LessBN(Cell):
    """
    Reduce the number of BN automatically to improve the network performance
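For reference, the forward pass described in the CommonHeadLastFN docstring above reduces to: L2-normalize the input rows and the weight rows, take their matrix product, optionally add the bias, and scale by the learnable multiplier. A minimal NumPy sketch of that computation (illustrative only, not part of this patch; shapes follow the Dense convention weight = [out_channels, in_channels]):

import numpy as np

def common_head_last_fn(x, weight, bias=None, multiplier=1.0):
    """Illustrative restatement of CommonHeadLastFN.construct."""
    x = x / np.linalg.norm(x, axis=1, keepdims=True)            # P.L2Normalize(axis=1) on the input
    w = weight / np.linalg.norm(weight, axis=1, keepdims=True)  # P.L2Normalize(axis=1) on the kernel
    out = x @ w.T                                               # P.MatMul(transpose_b=True)
    if bias is not None:
        out = out + bias                                        # P.BiasAdd
    return multiplier * out                                     # learnable scalar multiplier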
@@ -31,6 +97,44 @@ class LessBN(Cell):
        super(LessBN, self).__init__()
        self.network = network
        self.network.set_acc("less_bn")
        self.network.update_cell_prefix()
        self._convert_to_less_bn_net(self.network)
        self.network.add_flags(defer_inline=True)

    def _convert_dense(self, subcell):
        """
        convert dense cell to FN cell
        """
        prefix = subcell.param_prefix
        new_subcell = CommonHeadLastFN(subcell.in_channels,
                                       subcell.out_channels,
                                       subcell.weight,
                                       subcell.bias,
                                       subcell.has_bias)
        new_subcell.update_parameters_name(prefix + '.')

        return new_subcell

    def _convert_to_less_bn_net(self, net):
        """
        convert network to less_bn network
        """
        cells = net.name_cells()
        dense_name = []
        dense_list = []
        for name in cells:
            subcell = cells[name]
            if subcell == net:
                continue
            elif isinstance(subcell, (nn.Dense)):
                dense_name.append(name)
                dense_list.append(subcell)
            else:
                self._convert_to_less_bn_net(subcell)

        if dense_list:
            new_subcell = self._convert_dense(dense_list[-1])
            net.insert_child_to_cell(dense_name[-1], new_subcell)

    def construct(self, *inputs):
        return self.network(*inputs)
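A hedged usage sketch (not from this diff): LessBN wraps an already-built network, calls set_acc("less_bn") on it, and swaps the trailing nn.Dense head for CommonHeadLastFN, so training code only needs to wrap the backbone. `MyNet` below is a placeholder:

net = MyNet()                       # hypothetical user-defined backbone
net = LessBN(net)                   # trailing nn.Dense becomes CommonHeadLastFN, "less_bn" acc flag set
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)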
@@ -1048,7 +1048,7 @@ class Cell(Cell_):
        Some acceleration algorithms may affect the accuracy of the network, please choose carefully.

        Args:
            acc_type (:str:`less_bn`): accelerate algorithm.
            acc_type (str): accelerate algorithm.

        Raises:
            ValueError: If acc_type is not in the algorithm library.
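A one-line usage sketch for the argument documented above (illustrative; `net` is any Cell instance, and a string outside the algorithm library raises ValueError):

net.set_acc("less_bn")    # "less_bn" is the acceleration algorithm added by this changeset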
@@ -157,8 +157,8 @@ class Adagrad(Optimizer):
        params = self.parameters
        accum = self.accum
        grads = self.decay_weight(grads)
        grads = self.scale_grad(grads)
        grads = self.gradients_centralization(grads)
        grads = self.scale_grad(grads)
        lr = self.get_lr()
        if self.is_group_lr:
            success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum,
@@ -342,9 +342,9 @@ class Adam(Optimizer):
        moment1 = self.moment1
        moment2 = self.moment2
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self._grad_sparse_indices_deduplicate(gradients)
        gradients = self.gradients_centralization(gradients)
        lr = self.get_lr()

        beta1_power = self.beta1_power * self.beta1
@@ -222,9 +222,9 @@ class FTRL(Optimizer):
        moments = self.moments
        linear = self.linear
        grads = self.decay_weight(grads)
        grads = self.gradients_centralization(grads)
        grads = self.scale_grad(grads)
        grads = self._grad_sparse_indices_deduplicate(grads)
        grads = self.gradients_centralization(grads)
        lr = self.get_lr()

        success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
@@ -258,9 +258,9 @@ class LazyAdam(Optimizer):

    def construct(self, gradients):
        gradients = self.decay_weight(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self._grad_sparse_indices_deduplicate(gradients)
        gradients = self.gradients_centralization(gradients)
        lr = self.get_lr()

        self.beta1_power = self.beta1_power * self.beta1
@@ -163,8 +163,8 @@ class Momentum(Optimizer):
        params = self.params
        moments = self.moments
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        if self.is_group_lr:
            success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
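The optimizer hunks in this commit all thread gradients_centralization into the gradient-preprocessing chain at the top of construct, next to weight decay and loss-scale handling. A schematic sketch of that chain (simplified and illustrative, not the real Optimizer code; the exact ordering relative to scale_grad follows the per-optimizer hunks):

    def construct(self, gradients):                            # schematic only
        gradients = self.decay_weight(gradients)               # weight decay
        gradients = self.gradients_centralization(gradients)   # new GC step
        gradients = self.scale_grad(gradients)                 # undo loss scaling
        lr = self.get_lr()
        # ...hand the preprocessed gradients to the fused update op...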
@@ -443,10 +443,6 @@ class Optimizer(Cell):
                self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
                for param in group_param['params']:
                    validator.check_value_type("parameter", param, [Parameter], self.cls_name)
                    if "conv" not in param.name and self.grad_centralization is True:
                        raise ValueError("Grad centralization can be perform only on the conv layer. If the parameter"
                                         "is not a convolution layer, this parameter cannot be set to True.")

                grad_centralization_ = self.grad_centralization
            else:
                grad_centralization_ = grad_centralization
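The hunk above touches the group-parameter path where `grad_centralization` is read from a parameter group; the conv-only ValueError shown here appears to be dropped (the hunk shrinks from 10 lines to 6), with the shape guards added to the centralization helpers below doing the filtering instead. A hedged example of requesting GC through group parameters (the parameter-splitting lines and `net` are placeholders, not from this patch):

conv_params = [p for p in net.trainable_params() if 'conv' in p.name]
no_conv_params = [p for p in net.trainable_params() if 'conv' not in p.name]
group_params = [
    {'params': conv_params, 'grad_centralization': True},    # GC requested for conv weights only
    {'params': no_conv_params},                               # other parameters left untouched
]
opt = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)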
@@ -634,9 +630,16 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
    """Get grad with grad_centralization."""
    if if_apply:
        indices = gradient.indices
        values = op_gc(gradient.values, -1)
        shape = gradient.dense_shape
        return RowTensor(indices, values, shape)
        grad_shape = F.shape(gradient)
        axis = []
        for i in range(1, len(grad_shape)):
            axis.append(i)
        if len(axis) >= 1:
            if grad_shape[1] % 16 != 0:
                return gradient
            values = op_gc(gradient.values, axis)
            return RowTensor(indices, values, shape)
    return gradient

@@ -644,7 +647,14 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
def _tensor_apply_grad_centralization(if_apply, gradient):
    """Get grad with grad_centralization."""
    if if_apply:
        return op_gc(gradient, -1)
        axis = []
        grad_shape = F.shape(gradient)
        for i in range(1, len(grad_shape)):
            axis.append(i)
        if len(axis) >= 1:
            if grad_shape[1] % 16 != 0:
                return gradient
            return op_gc(gradient, axis)
    return gradient
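The helpers above apply gradient centralization only to tensors with at least two dimensions whose second dimension is a multiple of 16, and leave everything else unchanged; `op_gc` is assumed to subtract from the gradient its mean over the given axes, per the Gradient Centralization method. A standalone NumPy sketch of the same idea (illustrative, not part of the patch):

import numpy as np

def grad_centralization(grad):
    """Illustrative NumPy analogue of _tensor_apply_grad_centralization."""
    axis = tuple(range(1, grad.ndim))        # every axis except the output-channel axis 0
    if not axis:                             # 1-D gradients (e.g. biases) are left unchanged
        return grad
    if grad.shape[1] % 16 != 0:              # same alignment guard as the diff
        return grad
    return grad - grad.mean(axis=axis, keepdims=True)

g = np.random.randn(32, 16, 3, 3)            # e.g. a conv weight gradient (out, in, kh, kw)
gc = grad_centralization(g)
assert np.allclose(gc.reshape(32, -1).mean(axis=1), 0.0)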
@@ -169,9 +169,9 @@ class ProximalAdagrad(Optimizer):
        params = self.parameters
        accum = self.accum
        grads = self.decay_weight(grads)
        grads = self.gradients_centralization(grads)
        grads = self.scale_grad(grads)
        grads = self._grad_sparse_indices_deduplicate(grads)
        grads = self.gradients_centralization(grads)
        lr = self.get_lr()
        if self.is_group_lr:
            success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
@@ -204,8 +204,8 @@ class RMSProp(Optimizer):
    def construct(self, gradients):
        params = self.parameters
        gradients = self.decay_weight(gradients)
        gradients = self.scale_grad(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        if self.centered:
            if self.is_group_lr:
@@ -176,8 +176,8 @@ class SGD(Optimizer):
        params = self.parameters
        accum = self.accum
        stat = self.stat
        gradients = self.scale_grad(gradients)
        gradients = self.gradients_centralization(gradients)
        gradients = self.scale_grad(gradients)
        lr = self.get_lr()
        if self.is_group_lr:
            success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
@@ -26,9 +26,10 @@ centralization_op_info = TBERegOp("Centralization") \
    .attr("axis", "required", "listInt", "all") \
    .input(0, "x", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .op_pattern("reduce") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
    .get_op_info()