!14809 [Less BN] Add FN and GC optimizer support.
From: @linqingke  Reviewed-by: @guoqi1024, @xu-yfei  Signed-off-by: @xu-yfei
Commit: c43deb5469
@@ -20,6 +20,7 @@ namespace mindspore {
 namespace opt {
 namespace irpass {
 namespace {
+enum RemoveNodeType { kOtherNode = 0, kOptimizerNode };
 const char kLessBatchNormalizationPassName[] = "less_bn";
 constexpr auto kValidResidualStructureIndex = 1;
 constexpr auto kBNParametersStartIndex = 2;
@@ -63,6 +64,11 @@ const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
   {kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
 static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
   ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
+const std::set<PrimitivePtr> kNeedRemoveNodeSet{
+  prim::kPrimLoad,      prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
+  prim::kPrimApplyFtrl, prim::kPrimSGD,        prim::kPrimApplyRMSProp,  prim::kPrimAdam};
+static std::unordered_map<RemoveNodeType, std::unordered_set<size_t>> kRemoveIndex{
+  {RemoveNodeType::kOtherNode, {2}}, {RemoveNodeType::kOptimizerNode, {3, 5, 6}}};
 bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
   if (a == nullptr) {
@@ -73,13 +79,56 @@ bool NeedRemove(const ParameterPtr &a, const std::vector<AnfNodePtr> &parameter_list) {
   });
 }
 
+bool IsNotRealUseNode(const AnfNodePtr &node) {
+  for (const auto &prim : kNeedRemoveNodeSet) {
+    if (IsPrimitiveCNode(node, prim)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+CNodePtr ConvertRemoveNodeToVirtualNode(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  std::vector<AnfNodePtr> args;
+  size_t index = 0;
+  const auto &inputs = cnode->inputs();
+  auto remove_index = kRemoveIndex[RemoveNodeType::kOptimizerNode];
+  if (IsPrimitiveCNode(cnode, prim::kPrimLoad) || IsPrimitiveCNode(cnode, prim::kPrimRefToEmbed)) {
+    remove_index = kRemoveIndex[RemoveNodeType::kOtherNode];
+  }
+
+  (void)std::copy_if(
+    inputs.begin(), inputs.end(), std::back_inserter(args),
+    [&remove_index, &index](const AnfNodePtr &) { return remove_index.find(index++) != remove_index.end(); });
+
+  (void)args.insert(args.begin(), NewValueNode(prim::kPrimMakeTuple));
+  const auto &fg = cnode->func_graph();
+  MS_EXCEPTION_IF_NULL(fg);
+  auto new_make_tuple = fg->NewCNode(args);
+  return new_make_tuple;
+}
+
 bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &parameter) {
   auto param_output = manager->node_users().find(parameter);
   if (param_output == manager->node_users().end()) {
     return true;
   }
 
-  return false;
+  bool need_remove = true;
+  auto output_info_list = param_output->second;
+  for (const auto &output_info : output_info_list) {
+    const auto &node = output_info.first;
+    if (IsNotRealUseNode(node)) {
+      const auto &cnode = node->cast<CNodePtr>();
+      const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode);
+      manager->Replace(cnode, new_cnode);
+      continue;
+    }
+    need_remove = false;
+  }
+
+  return need_remove;
 }
 
 void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager,
@@ -98,12 +147,20 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager
     [&manager](const AnfNodePtr &param) { return IsRealRemoveParameterNode(manager, param); });
 
   auto root_parameters = root_graph->parameters();
+  size_t origin_param_count = root_parameters.size();
   root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(),
                                        [&real_remove_parameter_list](const AnfNodePtr &node) {
                                          return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list);
                                        }),
                         root_parameters.end());
+  size_t remove_param_count = origin_param_count - root_parameters.size();
+  size_t hyper_param_count = root_graph->hyper_param_count();
+  if (remove_param_count > hyper_param_count) {
+    MS_LOG(ERROR) << "The number of deleted parameters cannot exceed the number of original parameters.";
+    return;
+  }
+  hyper_param_count = hyper_param_count - remove_param_count;
+  root_graph->set_hyper_param_count(hyper_param_count);
   manager->SetParameters(root_graph, root_parameters);
 }
 }  // namespace
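In short, the pass treats a batch-norm parameter as removable only when every node that consumes it is one of the primitives in kNeedRemoveNodeSet (Load/RefToEmbed or an optimizer update); such consumers are rewritten into side-effect-free MakeTuple nodes that keep only the inputs selected by kRemoveIndex. A minimal Python sketch of that predicate (names are illustrative, not MindSpore API):

    # Illustrative sketch of IsRealRemoveParameterNode, not MindSpore code.
    NOT_REAL_USE = {"Load", "RefToEmbed", "ApplyMomentum", "Momentum",
                    "ApplyFtrl", "SGD", "ApplyRMSProp", "Adam"}

    def is_real_remove_parameter(users):
        """users: list of op names that consume the parameter."""
        need_remove = True
        for op_name in users:
            if op_name in NOT_REAL_USE:
                # the C++ pass replaces this consumer with a MakeTuple of the
                # inputs it keeps (kRemoveIndex) instead of deleting it outright
                continue
            need_remove = False
        return need_remove

    # A frozen BN gamma consumed only by Load and ApplyMomentum is removable:
    print(is_real_remove_parameter(["Load", "ApplyMomentum"]))  # True
    print(is_real_remove_parameter(["Conv2D", "Load"]))         # False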
@@ -231,6 +231,9 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list)
   std::vector<abstract::AbstractKeywordArgPtr> kwarg_list;
   std::vector<size_t> pos_arg_indexes;
   size_t arguments_count = args_spec_list.size();
+  if (hyper_param_count_ > arguments_count) {
+    MS_LOG(EXCEPTION) << "The number of parameters in funcgraph cannot exceed the number of arguments.";
+  }
   for (size_t i = 0; i < arguments_count - hyper_param_count_; i++) {
     MS_EXCEPTION_IF_NULL(args_spec_list[i]);
     if (args_spec_list[i]->isa<abstract::AbstractKeywordArg>()) {
@@ -13,8 +13,74 @@
 # limitations under the License.
 # ============================================================================
 """less batch normalization"""
+import numpy as np
+from mindspore import nn
+from mindspore.ops import operations as P
+from mindspore import Tensor, Parameter
+from mindspore import dtype as mstype
+from mindspore.common.initializer import initializer
 from ..cell import Cell
 
+
+__all__ = ["LessBN"]
+
+
+class CommonHeadLastFN(Cell):
+    r"""
+    The last full normalization layer.
+
+    This layer implements the operation as:
+
+    .. math::
+        \text{inputs} = \text{norm}(\text{inputs})
+        \text{kernel} = \text{norm}(\text{kernel})
+        \text{outputs} = \text{multiplier} * (\text{inputs} * \text{kernel} + \text{bias}),
+
+    Args:
+        in_channels (int): The number of channels in the input space.
+        out_channels (int): The number of channels in the output space.
+        weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
+            is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
+        bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
+            same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
+        has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
+
+    Supported Platforms:
+        ``Ascend`` ``GPU`` ``CPU``
+
+    Examples:
+        >>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
+        >>> net = CommonHeadLastFN(3, 4)
+        >>> output = net(input)
+    """
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 weight_init='normal',
+                 bias_init='zeros',
+                 has_bias=True):
+        super(CommonHeadLastFN, self).__init__()
+        weight_shape = [out_channels, in_channels]
+        self.weight = Parameter(initializer(weight_init, weight_shape), requires_grad=True, name='weight')
+        self.x_norm = P.L2Normalize(axis=1)
+        self.w_norm = P.L2Normalize(axis=1)
+        self.fc = P.MatMul(transpose_a=False, transpose_b=True)
+        self.multiplier = Parameter(Tensor(np.ones([1]), mstype.float32), requires_grad=True, name='multiplier')
+        self.has_bias = has_bias
+        if self.has_bias:
+            bias_shape = [out_channels]
+            self.bias_add = P.BiasAdd()
+            self.bias = Parameter(initializer(bias_init, bias_shape), requires_grad=True, name='bias')
+
+    def construct(self, x):
+        x = self.x_norm(x)
+        w = self.w_norm(self.weight)
+        x = self.fc(x, w)
+        if self.has_bias:
+            x = self.bias_add(x, self.bias)
+        x = self.multiplier * x
+        return x
+
 class LessBN(Cell):
     """
     Reduce the number of BN automatically to improve the network performance
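The FN head added above is essentially a cosine-similarity classifier with a learnable scale: both the input rows and the weight rows are L2-normalized before the matrix multiply. A minimal NumPy sketch of the same forward computation (illustrative only; the real layer uses MindSpore operators):

    import numpy as np

    def common_head_last_fn(x, weight, bias=None, multiplier=1.0, eps=1e-12):
        """NumPy sketch of CommonHeadLastFN.construct: L2-normalize the rows of the
        input and of the weight, take their inner products, then apply an optional
        bias and a learnable scalar multiplier."""
        x = x / (np.linalg.norm(x, axis=1, keepdims=True) + eps)
        w = weight / (np.linalg.norm(weight, axis=1, keepdims=True) + eps)
        out = x @ w.T                      # MatMul(transpose_b=True)
        if bias is not None:
            out = out + bias               # BiasAdd
        return multiplier * out

    x = np.array([[180., 234., 154.], [244., 48., 247.]])
    w = np.random.randn(4, 3)              # [out_channels, in_channels]
    print(common_head_last_fn(x, w).shape)  # (2, 4)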
@@ -31,6 +97,44 @@ class LessBN(Cell):
         super(LessBN, self).__init__()
         self.network = network
         self.network.set_acc("less_bn")
+        self.network.update_cell_prefix()
+        self._convert_to_less_bn_net(self.network)
+        self.network.add_flags(defer_inline=True)
+
+    def _convert_dense(self, subcell):
+        """
+        convert dense cell to FN cell
+        """
+        prefix = subcell.param_prefix
+        new_subcell = CommonHeadLastFN(subcell.in_channels,
+                                       subcell.out_channels,
+                                       subcell.weight,
+                                       subcell.bias,
+                                       subcell.has_bias)
+        new_subcell.update_parameters_name(prefix + '.')
+
+        return new_subcell
+
+    def _convert_to_less_bn_net(self, net):
+        """
+        convert network to less_bn network
+        """
+        cells = net.name_cells()
+        dense_name = []
+        dense_list = []
+        for name in cells:
+            subcell = cells[name]
+            if subcell == net:
+                continue
+            elif isinstance(subcell, (nn.Dense)):
+                dense_name.append(name)
+                dense_list.append(subcell)
+            else:
+                self._convert_to_less_bn_net(subcell)
+
+        if dense_list:
+            new_subcell = self._convert_dense(dense_list[-1])
+            net.insert_child_to_cell(dense_name[-1], new_subcell)
+
     def construct(self, *inputs):
         return self.network(*inputs)
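LessBN walks the child cells recursively, records every nn.Dense it finds at each level, and swaps only the last one for the FN head while setting the "less_bn" acceleration flag. A hedged usage sketch (the import path and `MyNet` are placeholders, not confirmed by this diff):

    # Hedged sketch: wrap an existing network so its final Dense layer becomes
    # CommonHeadLastFN and the less_bn acceleration flag is set on it.
    from mindspore.nn.acc import LessBN   # assumed location of this new module

    net = MyNet(num_classes=10)            # any user-defined Cell ending in nn.Dense
    net = LessBN(net)                       # set_acc("less_bn") + convert last Dense to FN
    # train as usual; LessBN.construct() simply forwards to the wrapped network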
@@ -1048,7 +1048,7 @@ class Cell(Cell_):
         Some acceleration algorithms may affect the accuracy of the network, please choose carefully.
 
         Args:
-            acc_type (:str:`less_bn`): accelerate algorithm.
+            acc_type (str): accelerate algorithm.
 
         Raises:
             ValueError: If acc_type is not in the algorithm library.
@@ -157,8 +157,8 @@ class Adagrad(Optimizer):
         params = self.parameters
         accum = self.accum
         grads = self.decay_weight(grads)
-        grads = self.scale_grad(grads)
         grads = self.gradients_centralization(grads)
+        grads = self.scale_grad(grads)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.map_(F.partial(_ada_grad_opt, self.opt), lr, params, accum,
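Across all the optimizer diffs here, the common change is the order of gradient preprocessing: gradient centralization is now applied right after weight decay and before the loss-scale is removed (and before sparse-index deduplication where present). Schematically (not real MindSpore code; the method names mirror the Optimizer helpers used above):

    # Per-step gradient pipeline after this change (sketch).
    def preprocess(gradients, opt):
        gradients = opt.decay_weight(gradients)              # weight decay
        gradients = opt.gradients_centralization(gradients)  # GC moved up to here
        gradients = opt.scale_grad(gradients)                # undo loss scaling
        gradients = opt._grad_sparse_indices_deduplicate(gradients)  # sparse optimizers only
        return gradients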
@@ -342,9 +342,9 @@ class Adam(Optimizer):
         moment1 = self.moment1
         moment2 = self.moment2
         gradients = self.decay_weight(gradients)
+        gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         gradients = self._grad_sparse_indices_deduplicate(gradients)
-        gradients = self.gradients_centralization(gradients)
         lr = self.get_lr()
 
         beta1_power = self.beta1_power * self.beta1
@@ -222,9 +222,9 @@ class FTRL(Optimizer):
         moments = self.moments
         linear = self.linear
         grads = self.decay_weight(grads)
+        grads = self.gradients_centralization(grads)
         grads = self.scale_grad(grads)
         grads = self._grad_sparse_indices_deduplicate(grads)
-        grads = self.gradients_centralization(grads)
         lr = self.get_lr()
 
         success = self.map_(F.partial(_ftrl_opt, self.opt, self.sparse_opt, self._ps_push, self._ps_pull,
@@ -258,9 +258,9 @@ class LazyAdam(Optimizer):
 
     def construct(self, gradients):
         gradients = self.decay_weight(gradients)
+        gradients = self.gradients_centralization(gradients)
         gradients = self.scale_grad(gradients)
         gradients = self._grad_sparse_indices_deduplicate(gradients)
-        gradients = self.gradients_centralization(gradients)
         lr = self.get_lr()
 
         self.beta1_power = self.beta1_power * self.beta1
@@ -163,8 +163,8 @@ class Momentum(Optimizer):
         params = self.params
         moments = self.moments
         gradients = self.decay_weight(gradients)
-        gradients = self.scale_grad(gradients)
         gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.hyper_map(F.partial(_momentum_opt, self.opt, self.momentum), lr, gradients, params, moments,
@@ -443,10 +443,6 @@ class Optimizer(Cell):
                 self.grad_centralization = self._preprocess_grad_centralization(group_param['grad_centralization'])
                 for param in group_param['params']:
                     validator.check_value_type("parameter", param, [Parameter], self.cls_name)
-                    if "conv" not in param.name and self.grad_centralization is True:
-                        raise ValueError("Grad centralization can be perform only on the conv layer. If the parameter"
-                                         "is not a convolution layer, this parameter cannot be set to True.")
-
                 grad_centralization_ = self.grad_centralization
             else:
                 grad_centralization_ = grad_centralization
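With the conv-only check removed, grad_centralization can be requested for any parameter group; whether the op actually modifies a gradient is now decided by its shape (see the axis/alignment logic below). A typical group-parameter setup, sketched from the standard MindSpore group-parameter API (`net` is any constructed Cell):

    from mindspore import nn

    # Sketch: enable gradient centralization on the convolution weights only,
    # which remains the usual choice even though the hard check is gone.
    conv_params = [p for p in net.trainable_params() if 'conv' in p.name]
    no_conv_params = [p for p in net.trainable_params() if 'conv' not in p.name]
    group_params = [{'params': conv_params, 'grad_centralization': True},
                    {'params': no_conv_params}]
    optim = nn.Momentum(group_params, learning_rate=0.1, momentum=0.9)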
@@ -634,9 +630,16 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
     """Get grad with grad_centralization."""
     if if_apply:
         indices = gradient.indices
-        values = op_gc(gradient.values, -1)
         shape = gradient.dense_shape
-        return RowTensor(indices, values, shape)
+        grad_shape = F.shape(gradient)
+        axis = []
+        for i in range(1, len(grad_shape)):
+            axis.append(i)
+        if len(axis) >= 1:
+            if grad_shape[1] % 16 != 0:
+                return gradient
+            values = op_gc(gradient.values, axis)
+            return RowTensor(indices, values, shape)
     return gradient
 
 
@@ -644,7 +647,14 @@ def _tensor_apply_grad_centralization_with_sparse(if_apply, gradient):
 def _tensor_apply_grad_centralization(if_apply, gradient):
     """Get grad with grad_centralization."""
     if if_apply:
-        return op_gc(gradient, -1)
+        axis = []
+        grad_shape = F.shape(gradient)
+        for i in range(1, len(grad_shape)):
+            axis.append(i)
+        if len(axis) >= 1:
+            if grad_shape[1] % 16 != 0:
+                return gradient
+            return op_gc(gradient, axis)
     return gradient
 
 
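Gradient centralization itself just removes the mean of each gradient slice over every axis except the first (output-channel) axis; the new axis list replaces the previous `-1`, and the `% 16` guard skips tensors whose second dimension is not 16-aligned, which appears to match the FracZ formats registered for the Centralization op below. A NumPy sketch of the computation (illustrative, not the TBE kernel):

    import numpy as np

    def gradient_centralization(grad):
        """NumPy sketch of op_gc with the axis list built above: subtract, for each
        output channel, the mean taken over all remaining axes."""
        axis = tuple(range(1, grad.ndim))
        if not axis:                 # 1-D gradients (e.g. biases) are left untouched
            return grad
        if grad.shape[1] % 16 != 0:  # mirrors the alignment guard in the diff
            return grad
        return grad - grad.mean(axis=axis, keepdims=True)

    g = np.random.randn(32, 64, 3, 3)        # conv weight gradient [out, in, kh, kw]
    gc = gradient_centralization(g)
    print(np.allclose(gc.reshape(32, -1).mean(axis=1), 0))  # True: per-filter mean is 0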
@@ -169,9 +169,9 @@ class ProximalAdagrad(Optimizer):
         params = self.parameters
         accum = self.accum
         grads = self.decay_weight(grads)
+        grads = self.gradients_centralization(grads)
         grads = self.scale_grad(grads)
         grads = self._grad_sparse_indices_deduplicate(grads)
-        grads = self.gradients_centralization(grads)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.map_(F.partial(_proximal_ada_grad_opt, self.opt, self.sparse_opt, self.l1, self.l2), lr,
@@ -204,8 +204,8 @@ class RMSProp(Optimizer):
     def construct(self, gradients):
         params = self.parameters
         gradients = self.decay_weight(gradients)
-        gradients = self.scale_grad(gradients)
         gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.centered:
             if self.is_group_lr:
@@ -176,8 +176,8 @@ class SGD(Optimizer):
         params = self.parameters
         accum = self.accum
         stat = self.stat
-        gradients = self.scale_grad(gradients)
         gradients = self.gradients_centralization(gradients)
+        gradients = self.scale_grad(gradients)
         lr = self.get_lr()
         if self.is_group_lr:
             success = self.hyper_map(F.partial(_sgd_opt, self.opt, self.momentum), lr, gradients, params, accum, stat)
@@ -26,9 +26,10 @@ centralization_op_info = TBERegOp("Centralization") \
     .attr("axis", "required", "listInt", "all") \
     .input(0, "x", False, "required", "all") \
     .output(0, "y", False, "required", "all") \
-    .op_pattern("reduce") \
     .dtype_format(DataType.F16_Default, DataType.F16_Default) \
     .dtype_format(DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \
     .get_op_info()