forked from mindspore-Ecosystem/mindspore
!12898 support parameter freeze
From: @zhangbuxue Reviewed-by: Signed-off-by:
commit e3dc9a75cc
@@ -478,10 +478,8 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
  auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
  // Update Graph Dynamic Shape Attr
  UpdateAllGraphDynamicShapeAttr(all_graphs);
  for (const auto &graph : all_graphs) {
    UnifyMindIR(graph);
  }
  BackendOptimization(all_graphs);
  UnifyMindIR(root_graph);
  opt::BackendCommonOptimization(root_graph);
  // empty graph dont entry to backend
  if (root_graph->execution_order().empty()) {
    MS_LOG(INFO) << root_graph->ToString() << " is empty graph.";
@@ -19,7 +19,6 @@
#include "backend/session/anf_runtime_algorithm.h"

namespace mindspore {

void CPUE2eDump::DumpCNodeData(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto &dump_json_parser = DumpJsonParser::GetInstance();
@@ -92,7 +92,7 @@ AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitiveP

  for (size_t i = 0; i < branches.size(); i++) {
    MS_EXCEPTION_IF_NULL(branches[i]);
    if (!branches[i]->isa<FuncGraphAbstractClosure>()) {
    if (!branches[i]->isa<FuncGraphAbstractClosure>() && !branches[i]->isa<PartialAbstractClosure>()) {
      MS_EXCEPTION(ValueError) << op_name << " requires that the 2th arg be tuple of functions, but got "
                               << branches[i]->ToString() << " as the " << i << "th element.";
    }
@@ -17,6 +17,8 @@ Accelerating.

Provide auto accelerating for network, such as Less BN.
"""
from .less_batch_normalization import LessBN
from .less_batch_normalization import *
from .grad_freeze import *

__all__ = ['LessBN']
__all__ = ['LessBN', 'FreezeOpt', 'CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
           'split_parameters_groups', 'generate_freeze_index_sequence']
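With the wider __all__, the gradient-freeze helpers become part of the package's public surface alongside LessBN. A minimal import sketch (assuming the package is used as mindspore.nn.acc, which matches the "from ..nn import acc" imports later in this commit):

from mindspore.nn.acc import (LessBN, FreezeOpt, CONTINUOUS_STRATEGY, INTERVAL_STRATEGY,
                              split_parameters_groups, generate_freeze_index_sequence)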
@@ -0,0 +1,139 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""grad freeze"""

import numpy as np

from mindspore.nn.cell import Cell
from mindspore.nn.optim import Optimizer
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype

__all__ = ['CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
           'split_parameters_groups', 'generate_freeze_index_sequence',
           'FreezeOpt']

CONTINUOUS_STRATEGY = 0
INTERVAL_STRATEGY = 1


def split_parameters_groups(net, freeze_para_groups_number):
    """Split parameter groups for gradients freezing training."""
    grouped_params = []
    tmp = []
    for para in net.trainable_params():
        name = para.name
        # ensure 'bn' after 'conv' is not split
        if 'bn' in name or 'bias' in name:
            tmp.append(para)
        elif len(tmp) >= 3:
            grouped_params.append(tmp)
            tmp = [para]
        else:
            tmp.append(para)
    if tmp:
        grouped_params.append(tmp)
    stride = len(grouped_params) // freeze_para_groups_number
    freeze_grouped_params = [sum(grouped_params[i * stride:], []) for i in range(freeze_para_groups_number)]
    return freeze_grouped_params


def generate_freeze_index_sequence(parameter_groups_number, freeze_strategy, freeze_p, steps_per_epoch, max_epoch):
    """Generate index sequence for gradient freezing training."""
    total_step = steps_per_epoch * max_epoch * 1.01
    # local continuous freezing training strategy, as '00001234'
    if freeze_strategy == CONTINUOUS_STRATEGY:
        zero_cnt = int(freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5)
        sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number))
        freeze_idxes = []
        while len(freeze_idxes) < total_step:
            freeze_idxes += sub_idx
        return freeze_idxes
    # interval freezing training strategy, as '01020304'
    if freeze_strategy == INTERVAL_STRATEGY:
        index_all = list(range(1, parameter_groups_number))
        prob = [x / sum(index_all) for x in index_all]
        freeze_idxes = [0]
        zero_cnt = 1
        freeze_cnt = 0
        while len(freeze_idxes) < total_step:
            freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt)
            if freeze_p_cur < 1 - freeze_p:
                freeze_idxes.append(int(np.random.choice(index_all[::-1], p=prob)))
                freeze_cnt += 1
            else:
                freeze_idxes.append(0)
                zero_cnt += 1
        return freeze_idxes
    raise ValueError(f"Unsupported freezing training strategy '{freeze_strategy}'")


class FreezeOpt(Cell):
    """
    Optimizer that supports gradients freezing training.

    Args:
        opt (Optimizer): non-freezing optimizer instance, such as 'Momentum', 'SGD'.
        train_parameter_groups (Union[Tuple, List]): Groups of parameters for gradients freezing training.
        train_strategy (Union[tuple(int), list(int), Tensor]): Strategy for gradients freezing training.

    Supported Platforms:
        ``Ascend``
    """
    def __init__(self, opt, train_parameter_groups=None, train_strategy=None):
        super(FreezeOpt, self).__init__()
        if not isinstance(opt, Optimizer):
            raise TypeError(f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
        if train_strategy is not None and train_parameter_groups is None:
            raise ValueError("When the 'train_strategy' is specified, the value of 'train_parameter_groups' "
                             "must also be specified")
        opt_class = type(opt)
        opt_init_args = opt.init_args
        self.opts = []

        if train_parameter_groups is None:
            groups_num = 10
            step = 6
            parameters = opt.parameters
            para_groups = (parameters[(i * step):] for i in range(groups_num))
            self.opts = [opt_class(params=params, **opt_init_args) for params in para_groups]
        else:
            if not isinstance(train_parameter_groups, (tuple, list)):
                raise TypeError("The specified 'train_parameter_groups' should be tuple or list")
            for params in train_parameter_groups:
                if not isinstance(params, (tuple, list)):
                    raise TypeError("The each element of 'train_parameter_groups' should be tuple or list "
                                    "to store the Parameter")
                for para in params:
                    if not isinstance(para, Parameter):
                        raise TypeError("The element of each group should be the Parameter")

                # generate one-to-one opt corresponding to the parameter group
                self.opts.append(opt_class(params=params, **opt_init_args))

        if isinstance(train_strategy, (tuple, list)):
            for ele in train_strategy:
                if not isinstance(ele, int):
                    raise ValueError("The element in train_strategy should be int number")
            self.train_strategy = Tensor(train_strategy, mstype.int32)
        elif isinstance(train_strategy, Tensor):
            if train_strategy.ndim != 1 or train_strategy.dtype != mstype.int32:
                raise ValueError("When train_strategy is a Tensor, the dimension should be 1 and "
                                 "the dtype should be int32")
            self.train_strategy = train_strategy
        elif train_strategy is None:
            self.train_strategy = None
        else:
            raise TypeError("The specified 'train_strategy' should be None, tuple, list or Tensor")
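A hedged usage sketch of the new module: all function and class names come from the file above, while the network, group count, freeze probability, and schedule lengths are illustrative assumptions. The base optimizer must expose init_args (i.e. be decorated with opt_init_args_register, as Momentum is in a later hunk), because FreezeOpt re-instantiates it once per parameter group.

from mindspore import nn
from mindspore.nn import acc

net = MyNet()  # placeholder: any training Cell
groups = acc.split_parameters_groups(net, freeze_para_groups_number=4)
strategy = acc.generate_freeze_index_sequence(parameter_groups_number=len(groups),
                                              freeze_strategy=acc.INTERVAL_STRATEGY,
                                              freeze_p=0.7,
                                              steps_per_epoch=100,
                                              max_epoch=10)
momentum = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
freeze_opt = acc.FreezeOpt(momentum, train_parameter_groups=groups, train_strategy=strategy)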
@@ -14,12 +14,12 @@
# ============================================================================
"""less Batch Normalization"""
import numpy as np
from mindspore import nn
from mindspore.nn.cell import Cell
from mindspore.nn.layer import Dense
from mindspore.ops import operations as P
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.common import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.common.initializer import initializer
from ..cell import Cell


__all__ = ["LessBN"]
@@ -126,7 +126,7 @@ class LessBN(Cell):
            subcell = cells[name]
            if subcell == net:
                continue
            elif isinstance(subcell, (nn.Dense)):
            elif isinstance(subcell, (Dense)):
                dense_name.append(name)
                dense_list.append(subcell)
            else:
@@ -19,6 +19,7 @@ from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore._checkparam import Validator
from .optimizer import Optimizer
from .optimizer import opt_init_args_register

_momentum_opt = C.MultitypeFuncGraph("momentum_opt")
@@ -146,6 +147,7 @@ class Momentum(Optimizer):
        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
    """
    @opt_init_args_register
    def __init__(self, params, learning_rate, momentum, weight_decay=0.0, loss_scale=1.0, use_nesterov=False):
        super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale)
        Validator.check_value_type("momentum", momentum, [float], self.cls_name)
@@ -13,6 +13,7 @@
# limitations under the License.
# ============================================================================
"""optimizer"""
import inspect
from typing import Iterable

import numpy as np
@@ -33,7 +34,18 @@ from mindspore.context import ParallelMode
from mindspore import context
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

__all__ = ['Optimizer']
__all__ = ['Optimizer', 'opt_init_args_register']

def opt_init_args_register(fn):
    def deco(self, *args, **kwargs):
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        arguments.pop('self')
        arguments.pop('params')
        setattr(self, 'init_args', arguments)
        fn(self, *args, **kwargs)
    return deco


class Optimizer(Cell):
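A small sketch of what the decorator records: it binds the call arguments of the wrapped __init__, applies defaults, drops 'self' and 'params', and stores the rest as init_args so the same optimizer class can be rebuilt later. ToyOptimizer is a stand-in, and the import path is assumed from this file's location (mindspore/nn/optim/optimizer.py).

from mindspore.nn.optim.optimizer import opt_init_args_register

class ToyOptimizer:
    @opt_init_args_register
    def __init__(self, params, learning_rate, momentum=0.9):
        self.params = list(params)

opt = ToyOptimizer([1, 2, 3], learning_rate=0.1)
print(opt.init_args)                              # {'learning_rate': 0.1, 'momentum': 0.9}
rebuilt = ToyOptimizer([4, 5], **opt.init_args)   # how FreezeOpt re-creates per-group optimizers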
@@ -27,6 +27,7 @@ from ...ops import functional as F
from ...ops import operations as P
from ...ops.operations.comm_ops import _VirtualDataset
from ..cell import Cell
from ...nn import acc
from .grad_reducer import DistributedGradReducer

_get_datatype = C.MultitypeFuncGraph("_get_datatype")
@@ -292,6 +293,32 @@ class ForwardValueAndGrad(Cell):
        return loss, grads


class _TrainFreezeCell(Cell):
    """Gradient freezing training network."""
    def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, max_accumulation_step, optimizer):
        super(_TrainFreezeCell, self).__init__(auto_prefix=False)
        self.net = net
        self.grad = grad
        self.grad_reducer = grad_reducer
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.sens = sens
        self.use_grad_accumulation = use_grad_accumulation
        if use_grad_accumulation:
            self.grad_accumulation = GradientAccumulation(max_accumulation_step, optimizer)

    def construct(self, *inputs):
        loss = self.net(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.net, self.parameters)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        else:
            loss = F.depend(loss, self.opt(grads))
        return loss


class TrainOneStepCell(Cell):
    r"""
    Network training package class.
@@ -302,7 +329,7 @@ class TrainOneStepCell(Cell):

    Args:
        network (Cell): The training network. The network only supports single output.
        optimizer (Cell): Optimizer for updating the weights.
        optimizer (Union[Cell]): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

    Inputs:
@@ -348,41 +375,66 @@
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.freeze = isinstance(optimizer, acc.FreezeOpt)
        self.optimizer = optimizer
        if not self.freeze:
            self.weights = self.optimizer.parameters
        self.train_strategy = getattr(self.optimizer, 'train_strategy', None)
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = F.identity
        self.parallel_mode = _get_parallel_mode()
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(self.weights, self.mean, self.degree)
        self.use_grad_accumulation = False
        if self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE):
            self.use_grad_accumulation = True
        self.reducer_flag = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL)
        self.use_grad_accumulation = self.parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.STAND_ALONE)
        if self.use_grad_accumulation:
            self.max_accumulation_step = get_auto_parallel_context("grad_accumulation_step")
            if self.max_accumulation_step <= 1:
                self.max_accumulation_step = 1
                self.use_grad_accumulation = False
        self.grad_accumulation = None
        if self.use_grad_accumulation:
            self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            if self.freeze:
                self.grad_reducers = (DistributedGradReducer(opt.parameters, self.mean, self.degree)
                                      for opt in self.optimizer.opts)
                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, reducer,
                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
                                         for reducer, opt in zip(self.grad_reducers, self.optimizer))
            else:
                self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, self.mean, self.degree)
        else:
            if self.freeze:
                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, self.grad_reducer,
                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
                                         for opt in self.optimizer.opts)
        self.step = Parameter(Tensor(0, dtype=mstype.int32))

    def construct(self, *inputs):
        weights = self.weights
        loss = self.network(*inputs)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(*inputs, sens)
        grads = self.grad_reducer(grads)

        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        if self.freeze:
            if self.train_strategy is None:
                step = self.step
                max_index = len(self.freeze_nets)
            else:
                step = self.train_strategy[self.step]
                max_index = len(self.train_strategy)
            loss = self.freeze_nets[step](*inputs)
            if self.step + 1 >= max_index:
                self.step = 0
            else:
                self.step += 1
        else:
            loss = F.depend(loss, self.optimizer(grads))
            loss = self.network(*inputs)
            sens = F.fill(loss.dtype, loss.shape, self.sens)
            grads = self.grad(self.network, self.weights)(*inputs, sens)
            grads = self.grad_reducer(grads)
            if self.use_grad_accumulation:
                loss = self.grad_accumulation(loss, grads)
            else:
                loss = F.depend(loss, self.optimizer(grads))
        return loss
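A hedged sketch of driving the wrapper once a FreezeOpt is passed in, continuing the names from the sketch after grad_freeze.py; the loss wrapper and data iteration are placeholders.

loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
net_with_loss = nn.WithLossCell(net, loss_fn)
train_net = nn.TrainOneStepCell(net_with_loss, freeze_opt)   # freeze branch is selected automatically
train_net.set_train()

for data, label in dataset:   # placeholder iteration over a dataset
    loss = train_net(data, label)
    # each step runs the _TrainFreezeCell picked by train_strategy[self.step],
    # and self.step wraps back to 0 when it reaches len(train_strategy)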
@@ -19,6 +19,7 @@ from .. import nn
from .._checkparam import Validator as validator
from .._checkparam import Rel
from ..common import dtype as mstype
from ..nn import acc
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..ops import functional as F
from ..parallel._utils import _get_parallel_mode
@@ -139,7 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
            scale the loss by `LossScaleManager`. If set, overwrite the level setting.
    """
    validator.check_value_type('network', network, nn.Cell)
    validator.check_value_type('optimizer', optimizer, nn.Optimizer)
    validator.check_value_type('optimizer', optimizer, (nn.Optimizer, acc.FreezeOpt))
    validator.check('level', level, "", ['O0', 'O2', 'O3', "auto"], Rel.IN)

    if level == "auto":
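At the training-API level the relaxed check means a FreezeOpt can be handed to the mixed-precision builder like a plain optimizer. A sketch under the assumption that this file is mindspore/train/amp.py and build_train_network is imported from there:

from mindspore.train.amp import build_train_network

train_net = build_train_network(net, freeze_opt, loss_fn=loss_fn, level='O0')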