!12898 support parameter freeze

From: @zhangbuxue
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2021-04-23 17:04:16 +08:00 committed by Gitee
commit e3dc9a75cc
10 changed files with 240 additions and 35 deletions

View File

@@ -478,10 +478,8 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
// Update Graph Dynamic Shape Attr
UpdateAllGraphDynamicShapeAttr(all_graphs);
for (const auto &graph : all_graphs) {
UnifyMindIR(graph);
}
BackendOptimization(all_graphs);
UnifyMindIR(root_graph);
opt::BackendCommonOptimization(root_graph);
// empty graph dont entry to backend
if (root_graph->execution_order().empty()) {
MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";

View File

@@ -19,7 +19,6 @@
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
void CPUE2eDump::DumpCNodeData(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto &dump_json_parser = DumpJsonParser::GetInstance();

View File

@@ -92,7 +92,7 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP
for (size_t i = 0; i < branches.size(); i++) {
MS_EXCEPTION_IF_NULL(branches[i]);
if (!branches[i]->isa<FuncGraphAbstractClosure>()) {
if (!branches[i]->isa<FuncGraphAbstractClosure>() && !branches[i]->isa<PartialAbstractClosure>()) {
MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
<< branches[i]->ToString() << " as the " << i << "th element.";
}
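This relaxation lets the switch_layer-style branch check accept branches that arrive as partial closures, not only plain function-graph closures. A minimal sketch of the pattern involved, assuming graph mode and the usual tuple-of-cells indexing that lowers to switch_layer; this is the same shape as the `self.freeze_nets[step](*inputs)` call added to TrainOneStepCell later in this commit, and exactly how each branch is abstracted depends on how it captures its inputs:

import numpy as np
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE)

class BranchNet(nn.Cell):
    """Select one sub-network with a scalar int32 Tensor index."""
    def __init__(self):
        super(BranchNet, self).__init__()
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.branches = (self.relu, self.sigmoid)

    def construct(self, x, index):
        # Indexing a tuple of cells with a Tensor index is inferred through switch_layer.
        return self.branches[index](x)

net = BranchNet()
out = net(Tensor(np.array([[-1.0, 2.0]], np.float32)), Tensor(1, mstype.int32))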

View File

@@ -17,6 +17,8 @@ Accelerating.
Provide auto accelerating for network, such as Less BN.
"""
from .less_batch_normalization import LessBN
from .less_batch_normalization import *
from .grad_freeze import *
__all__ = ['LessBN']
__all__ = ['LessBN', 'FreezeOpt', 'CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
           'split_parameters_groups', 'generate_freeze_index_sequence']
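With the widened export list, the freeze helpers become importable alongside LessBN; a small sketch (the package path mindspore.nn.acc is assumed from this diff's file layout):

from mindspore.nn.acc import (LessBN, FreezeOpt, CONTINUOUS_STRATEGY, INTERVAL_STRATEGY,
                              split_parameters_groups, generate_freeze_index_sequence)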

View File

@@ -0,0 +1,139 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad freeze"""
import numpy as np
from mindspore.nn.cell import Cell
from mindspore.nn.optim import Optimizer
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype

__all__ = ['CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
           'split_parameters_groups', 'generate_freeze_index_sequence',
           'FreezeOpt']

CONTINUOUS_STRATEGY = 0
INTERVAL_STRATEGY = 1


def split_parameters_groups(net, freeze_para_groups_number):
    """Split parameter groups for gradients freezing training."""
    grouped_params = []
    tmp = []
    for para in net.trainable_params():
        name = para.name
        # ensure 'bn' after 'conv' is not split
        if 'bn' in name or 'bias' in name:
            tmp.append(para)
        elif len(tmp) >= 3:
            grouped_params.append(tmp)
            tmp = [para]
        else:
            tmp.append(para)
    if tmp:
        grouped_params.append(tmp)
    stride = len(grouped_params) // freeze_para_groups_number
    freeze_grouped_params = [sum(grouped_params[i * stride:], []) for i in range(freeze_para_groups_number)]
    return freeze_grouped_params
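The grouping walks the trainable parameters in order, keeps 'bn'/'bias' parameters attached to the preceding block, and then builds overlapping groups (group i spans block i to the end). A hedged sketch on a toy network — TinyNet and the printed sizes are illustrative, and the mindspore.nn.acc import path is assumed from this diff:

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.nn.acc import split_parameters_groups  # package path assumed

class TinyNet(nn.Cell):
    """Toy conv/bn stack; parameter names containing 'bn'/'bias' steer the grouping."""
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 8, 3)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.Conv2d(8, 8, 3)
        self.bn2 = nn.BatchNorm2d(8)
        self.fc = nn.Dense(8, 10)
        self.mean = P.ReduceMean(keep_dims=False)

    def construct(self, x):
        x = self.bn1(self.conv1(x))
        x = self.bn2(self.conv2(x))
        return self.fc(self.mean(x, (2, 3)))

groups = split_parameters_groups(TinyNet(), freeze_para_groups_number=2)
print([len(g) for g in groups])  # overlapping groups, e.g. [8, 5]: group 0 holds all
                                 # parameters, group 1 drops the earliest conv/bn block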


def generate_freeze_index_sequence(parameter_groups_number, freeze_strategy, freeze_p, steps_per_epoch, max_epoch):
    """Generate index sequence for gradient freezing training."""
    total_step = steps_per_epoch * max_epoch * 1.01
    # local continuous freezing training strategy, as '00001234'
    if freeze_strategy == CONTINUOUS_STRATEGY:
        zero_cnt = int(freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5)
        sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number))
        freeze_idxes = []
        while len(freeze_idxes) < total_step:
            freeze_idxes += sub_idx
        return freeze_idxes
    # interval freezing training strategy, as '01020304'
    if freeze_strategy == INTERVAL_STRATEGY:
        index_all = list(range(1, parameter_groups_number))
        prob = [x / sum(index_all) for x in index_all]
        freeze_idxes = [0]
        zero_cnt = 1
        freeze_cnt = 0
        while len(freeze_idxes) < total_step:
            freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt)
            if freeze_p_cur < 1 - freeze_p:
                freeze_idxes.append(int(np.random.choice(index_all[::-1], p=prob)))
                freeze_cnt += 1
            else:
                freeze_idxes.append(0)
                zero_cnt += 1
        return freeze_idxes
    raise ValueError(f"Unsupported freezing training strategy '{freeze_strategy}'")


class FreezeOpt(Cell):
    """
    Optimizer that supports gradients freezing training.

    Args:
        opt (Optimizer): non-freezing optimizer instance, such as 'Momentum', 'SGD'.
        train_parameter_groups (Union[Tuple, List]): Groups of parameters for gradients freezing training.
        train_strategy (Union[tuple(int), list(int), Tensor]): Strategy for gradients freezing training.

    Supported Platforms:
        ``Ascend``
    """

    def __init__(self, opt, train_parameter_groups=None, train_strategy=None):
        super(FreezeOpt, self).__init__()
        if not isinstance(opt, Optimizer):
            raise TypeError(f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
        if train_strategy is not None and train_parameter_groups is None:
            raise ValueError("When the 'train_strategy' is specified, the value of 'train_parameter_groups' "
                             "must also be specified")
        opt_class = type(opt)
        opt_init_args = opt.init_args
        self.opts = []
        if train_parameter_groups is None:
            groups_num = 10
            step = 6
            parameters = opt.parameters
            para_groups = (parameters[(i * step):] for i in range(groups_num))
            self.opts = [opt_class(params=params, **opt_init_args) for params in para_groups]
        else:
            if not isinstance(train_parameter_groups, (tuple, list)):
                raise TypeError("The specified 'train_parameter_groups' should be tuple or list")
            for params in train_parameter_groups:
                if not isinstance(params, (tuple, list)):
                    raise TypeError("The each element of 'train_parameter_groups' should be tuple or list "
                                    "to store the Parameter")
                for para in params:
                    if not isinstance(para, Parameter):
                        raise TypeError("The element of each group should be the Parameter")
                # generate one-to-one opt corresponding to the parameter group
                self.opts.append(opt_class(params=params, **opt_init_args))
        if isinstance(train_strategy, (tuple, list)):
            for ele in train_strategy:
                if not isinstance(ele, int):
                    raise ValueError("The element in train_strategy should be int number")
            self.train_strategy = Tensor(train_strategy, mstype.int32)
        elif isinstance(train_strategy, Tensor):
            if train_strategy.ndim != 1 or train_strategy.dtype != mstype.int32:
                raise ValueError("When train_strategy is a Tensor, the dimension should be 1 and "
                                 "the dtype should be int32")
            self.train_strategy = train_strategy
        elif train_strategy is None:
            self.train_strategy = None
        else:
            raise TypeError("The specified 'train_strategy' should be None, tuple, list or Tensor")

View File

@@ -14,12 +14,12 @@
# ============================================================================
"""less Batch Normalization"""
import numpy as np
from mindspore import nn
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.ops import operations as P
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from ..cell import Cell
__all__ = ["LessBN"]
@@ -126,7 +126,7 @@ class LessBN(Cell):
            subcell = cells[name]
            if subcell == net:
                continue
            elif isinstance(subcell, (nn.Dense)):
            elif isinstance(subcell, (Dense)):
                dense_name.append(name)
                dense_list.append(subcell)
            else:

View File

@@ -19,6 +19,7 @@ from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from .optimizer import Optimizer
from .optimizer import opt_init_args_register
_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -146,6 +147,7 @@ class Momentum(Optimizer):
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    @opt_init_args_register
    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
        super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
        Validator.check_value_type("momentum", momentum, [float], self.cls_name)

View File

@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""optimizer"""
import inspect
from typing import Iterable
import numpy as np
@@ -33,7 +34,18 @@ from mindspore.context import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule
__all__ = ['Optimizer']
__all__ = ['Optimizer', 'opt_init_args_register']


def opt_init_args_register(fn):
    def deco(self, *args, **kwargs):
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        arguments.pop('self')
        arguments.pop('params')
        setattr(self, 'init_args', arguments)
        fn(self, *args, **kwargs)
    return deco


class Optimizer(Cell):
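For reference, what the decorator above ends up recording on a decorated optimizer (Momentum is decorated in the previous file); a small sketch, with the printed mapping illustrative:

import mindspore.nn as nn

net = nn.Dense(16, 10)
optim = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

# Every constructor argument except `self` and `params` is stashed on the instance,
# which is what lets FreezeOpt re-instantiate the same optimizer class per parameter group.
print(optim.init_args)
# e.g. {'learning_rate': 0.1, 'momentum': 0.9, 'weight_decay': 0.0,
#       'loss_scale': 1.0, 'use_nesterov': False}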

View File

@@ -27,6 +27,7 @@ from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from ...nn import acc
from .grad_reducer import DistributedGradReducer
_get_datatype = C.MultitypeFuncGraph("_get_datatype")
@@ -292,6 +293,32 @@ class ForwardValueAndGrad(Cell):
        return loss, grads


class _TrainFreezeCell(Cell):
    """Gradient freezing training network."""
    def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, max_accumulation_step, optimizer):
        super(_TrainFreezeCell, self).__init__(auto_prefix=False)
        self.net = net
        self.grad = grad
        self.grad_reducer = grad_reducer
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.sens = sens
        self.use_grad_accumulation = use_grad_accumulation
        if use_grad_accumulation:
            self.grad_accumulation = GradientAccumulation(max_accumulation_step, optimizer)

    def construct(self, *inputs):
        loss = self.net(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.net, self.parameters)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        else:
            loss = F.depend(loss, self.opt(grads))
        return loss


class TrainOneStepCell(Cell):
r"""
Network training package class.
@ -302,7 +329,7 @@ class TrainOneStepCell(Cell):
Args:
network (Cell): The training network. The network only supports single output.
optimizer (Cell): Optimizer for updating the weights.
optimizer (Union[Cell]): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
Inputs:
@@ -348,41 +375,66 @@ class TrainOneStepCell(Cell):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.freeze = isinstance(optimizer, acc.FreezeOpt)
        self.optimizer = optimizer
        if not self.freeze:
            self.weights = self.optimizer.parameters
        self.train_strategy = getattr(self.optimizer, 'train_strategy', None)
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
        self.use_grad_accumulation = False
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
            self.use_grad_accumulation = True
        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
        self.use_grad_accumulation = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE)
        if self.use_grad_accumulation:
            self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
            if self.max_accumulation_step <= 1:
                self.max_accumulation_step = 1
                self.use_grad_accumulation = False
        self.grad_accumulation = None
        if self.use_grad_accumulation:
            self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            if self.freeze:
                self.grad_reducers = (DistributedGradReducer(opt.parameters, self.mean, self.degree)
                                      for opt in self.optimizer.opts)
                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, reducer,
                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
                                         for reducer, opt in zip(self.grad_reducers, self.optimizer))
            else:
                self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, self.mean, self.degree)
        else:
            if self.freeze:
                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, self.grad_reducer,
                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
                                         for opt in self.optimizer.opts)
        self.step = Parameter(Tensor(0, dtype=mstype.int32))

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        if self.freeze:
            if self.train_strategy is None:
                step = self.step
                max_index = len(self.freeze_nets)
            else:
                step = self.train_strategy[self.step]
                max_index = len(self.train_strategy)
            loss = self.freeze_nets[step](*inputs)
            if self.step + 1 >= max_index:
                self.step = 0
            else:
                self.step += 1
        else:
            loss = F.depend(loss, self.optimizer(grads))
            loss = self.network(*inputs)
            sens = F.fill(loss.dtype, loss.shape, self.sens)
            grads = self.grad(self.network, self.weights)(*inputs, sens)
            grads = self.grad_reducer(grads)
            if self.use_grad_accumulation:
                loss = self.grad_accumulation(loss, grads)
            else:
                loss = F.depend(loss, self.optimizer(grads))
        return loss
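End to end, the freeze path can be exercised like any other TrainOneStepCell; a hedged sketch reusing the net and freeze_opt from the earlier sketches (per the FreezeOpt docstring this targets Ascend; shapes and names here are illustrative):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net_with_loss = nn.WithLossCell(net, loss_fn)
train_net = nn.TrainOneStepCell(net_with_loss, freeze_opt)
train_net.set_train()

data = Tensor(np.random.randn(2, 3, 32, 32).astype(np.float32))
label = Tensor(np.array([1, 0]).astype(np.int32))
# Each call routes to freeze_nets[train_strategy[step]], so only that group's
# optimizer updates its parameters; `step` then advances, wrapping at max_index.
loss = train_net(data, label)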

View File

@@ -19,6 +19,7 @@ from .. import nn
from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn import acc
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..ops import functional as F
from ..parallel._utils import _get_parallel_mode
@@ -139,7 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
scale the loss by `LossScaleManager`. If set, overwrite the level setting.
"""
validator.check_value_type('network', network, nn.Cell)
validator.check_value_type('optimizer', optimizer, nn.Optimizer)
validator.check_value_type('optimizer', optimizer, (nn.Optimizer, acc.FreezeOpt))
validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN)
if level == "auto":