!17597 [ACC] Gradient freeze optimization.

Merge pull request !17597 from linqingke/grad_freeze
This commit is contained in:
i-robot 2021-06-22 03:32:00 +00:00 committed by Gitee
commit a00cbb6e22
11 changed files with 598 additions and 160 deletions

View File

@ -31,11 +31,15 @@ constexpr auto kFirstBranchPattern1 = 12;
constexpr auto kSecondBranchPattern1 = 3;
constexpr auto kFirstBranchStartIndexPattern1 = 4;
constexpr auto kFirstBranchEndIndexPattern1 = 11;
constexpr auto kSecondBranchStartIndexPattern1 = 12;
constexpr auto kSecondBranchEndIndexPattern1 = 14;
const std::vector<kStructureTuple> ResidualStructureBasePattern{
{kFirstBranchPattern1,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
{kFirstBranchStartIndexPattern1, kFirstBranchEndIndexPattern1}},
{kSecondBranchPattern1, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
{kSecondBranchPattern1,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
{kSecondBranchStartIndexPattern1, kSecondBranchEndIndexPattern1}}};
// Pattern 2
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ -> -> ... ... ... -> -> ↗
@ -43,11 +47,13 @@ constexpr auto kFirstBranchPattern2 = 12;
constexpr auto kSecondBranchPattern2 = 1;
constexpr auto kFirstBranchStartIndexPattern2 = 4;
constexpr auto kFirstBranchEndIndexPattern2 = 11;
constexpr auto kSecondBranchStartIndexPattern2 = 12;
constexpr auto kSecondBranchEndIndexPattern2 = 13;
const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
{kFirstBranchPattern2,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
{kFirstBranchStartIndexPattern2, kFirstBranchEndIndexPattern2}},
{kSecondBranchPattern2, {prim::kPrimRelu}, {SIZE_MAX, SIZE_MAX}}};
{kSecondBranchPattern2, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern2, kSecondBranchEndIndexPattern2}}};
// Pattern 3
// Add -> BatchNorm -> Conv2D -> Relu ... BatchNorm -> Conv2D -> End
// ↘ BatchNorm -> Conv2D -> -> ... ... ... -> -> ↗
@ -55,15 +61,65 @@ constexpr auto kFirstBranchPattern3 = 11;
constexpr auto kSecondBranchPattern3 = 3;
constexpr auto kFirstBranchStartIndexPattern3 = 4;
constexpr auto kFirstBranchEndIndexPattern3 = 10;
constexpr auto kSecondBranchStartIndexPattern3 = 11;
constexpr auto kSecondBranchEndIndexPattern3 = 13;
const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
{kFirstBranchPattern3,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
prim::kPrimConv2D},
{kFirstBranchStartIndexPattern3, kFirstBranchEndIndexPattern3}},
{kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
{kSecondBranchPattern3,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
{kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}};
// Pattern 4
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ BatchNorm -> Conv2D -> -> -> -> ↗
constexpr auto kFirstBranchPattern4 = 8;
constexpr auto kSecondBranchPattern4 = 3;
constexpr auto kFirstBranchStartIndexPattern4 = 4;
constexpr auto kFirstBranchEndIndexPattern4 = 6;
constexpr auto kSecondBranchStartIndexPattern4 = 8;
constexpr auto kSecondBranchEndIndexPattern4 = 11;
const std::vector<kStructureTuple> BasicStructureBasePattern{
{kFirstBranchPattern4,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
{kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}},
{kSecondBranchPattern4,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
{kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}};
// Pattern 5
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ -> -> -> -> Relu -> -> -> -> ↗
constexpr auto kFirstBranchPattern5 = 8;
constexpr auto kSecondBranchPattern5 = 1;
constexpr auto kFirstBranchStartIndexPattern5 = 4;
constexpr auto kFirstBranchEndIndexPattern5 = 6;
constexpr auto kSecondBranchStartIndexPattern5 = 8;
constexpr auto kSecondBranchEndIndexPattern5 = 11;
const std::vector<kStructureTuple> BasicStructureShortCutPattern{
{kFirstBranchPattern5,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
{kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}},
{kSecondBranchPattern5, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
// Pattern 6
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ -> -> -> -> MaxPool -> -> -> ↗
constexpr auto kFirstBranchPattern6 = 7;
constexpr auto kSecondBranchPattern6 = 1;
constexpr auto kFirstBranchStartIndexPattern6 = 4;
constexpr auto kFirstBranchEndIndexPattern6 = 6;
constexpr auto kSecondBranchStartIndexPattern6 = 7;
constexpr auto kSecondBranchEndIndexPattern6 = 10;
const std::vector<kStructureTuple> BasicStructureFirstStepPattern{
{kFirstBranchPattern6,
{prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
prim::kPrimBatchNorm, prim::kPrimConv2D},
{kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}},
{kSecondBranchPattern6, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern,
BasicStructureBasePattern, BasicStructureShortCutPattern, BasicStructureFirstStepPattern};
const std::set<PrimitivePtr> kNeedRemoveNodeSet{
prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam};
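The table above drives the gradient-freeze graph pass: each kStructureTuple bundles a branch's node count, the primitive sequence expected along that branch, and the (start, end) index range it occupies in the matched structure, and this commit grows the table from three residual patterns to six by adding the three Basic* variants. Below is a minimal Python mirror of that data layout, purely illustrative and not part of the C++ pass; the helper name `describe` is invented for the sketch.

# Conceptual mirror of kStructureTuple: (branch node count, primitive sequence,
# (start index, end index)). Values copied from ResidualStructureBasePattern above.
residual_structure_base_pattern = [
    (12, ("TupleGetItem", "BatchNorm", "Conv2D", "Relu"), (4, 11)),
    (3, ("TupleGetItem", "BatchNorm", "Conv2D"), (12, 14)),
]

def describe(pattern):
    """Print each branch of a pattern table in a readable form."""
    for size, prims, (start, end) in pattern:
        print(f"{size} nodes, primitives {'/'.join(prims)}, index range [{start}, {end}]")

describe(residual_structure_base_pattern)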

View File

@ -15,10 +15,16 @@
"""
Accelerating.
Provide auto accelerating for network, such as Less BN.
Provide auto acceleration for networks, such as Less BN and Gradient Freeze.
"""
from .acc import *
from .base import *
from .less_batch_normalization import *
from .grad_freeze import *
__all__ = ['LessBN', 'FreezeOpt', 'CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
'split_parameters_groups', 'generate_freeze_index_sequence']
__all__ = ['AutoAcc',
'OptimizerProcess', 'ParameterProcess',
'LessBN',
'GradientFreeze', 'FreezeOpt', 'freeze_cell',
'GradientAccumulation']

125
mindspore/nn/acc/acc.py Normal file
View File

@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""acc"""
from easydict import EasyDict as edict
from .less_batch_normalization import LessBN
from .grad_freeze import GradientFreeze
from .base import OptimizerProcess, ParameterProcess
__all__ = ["AutoAcc"]
_acc_config_level = {
"O0": {
"less_bn": False,
"grad_freeze": False},
"O1": {
"less_bn": True,
"grad_freeze": False}}
class AutoAcc:
"""
Provide auto acceleration for the network.
Args:
level (str): The acc configuration level.
"""
def __init__(self, level, kwargs):
if level not in _acc_config_level.keys():
level = 'O0'
acc_config = _acc_config_level[level]
self._acc_config = edict(acc_config)
self._fn_flag = True
self._gc_flag = True
self._param_groups = 10
self._freeze_type = 1
self._freeze_p = 0.5
self._total_steps = -1
self._gradient_groups = None
self._get_configuration(kwargs)
self._param_processer = ParameterProcess()
def _get_configuration(self, kwargs):
"""Get configuration."""
for key, val in kwargs.items():
if key not in self._acc_config_func_map.keys():
continue
self._acc_config_func_map[key](self, val)
def network_auto_process_train(self, network, optimizer):
"""Network train."""
if self._acc_config.less_bn:
network = LessBN(network, fn_flag=self._fn_flag)
optimizer_process = OptimizerProcess(optimizer)
group_params = self._param_processer.assign_parameter_group(network.trainable_params(),
self._gradient_groups)
optimizer_process.origin_params = self._param_processer.generate_group_params(group_params)
if self._gc_flag:
parameters = optimizer_process.add_grad_centralization()
else:
parameters = optimizer_process.origin_params
optimizer = optimizer_process.generate_new_optimizer(parameters)
if self._acc_config.grad_freeze:
freeze_processer = GradientFreeze(self._param_groups, self._freeze_type,
self._freeze_p, self._total_steps)
network, optimizer = freeze_processer.freeze_generate(network, optimizer)
return network, optimizer
def network_auto_process_eval(self, network):
"""Network eval."""
if self._acc_config.less_bn:
network = LessBN(network)
return network
def set_fn_flag(self, fn_flag):
self._fn_flag = fn_flag
def set_gc_flag(self, gc_flag):
self._gc_flag = gc_flag
def set_param_groups(self, param_groups):
self._param_groups = param_groups
def set_freeze_type(self, freeze_type):
self._freeze_type = freeze_type
def set_freeze_p(self, freeze_p):
self._freeze_p = freeze_p
def set_total_steps(self, total_steps):
self._total_steps = total_steps
def set_gradient_groups(self, gradient_groups):
if not isinstance(gradient_groups, (list, int)):
raise ValueError(f"gradient_groups `{gradient_groups}` is not in (list, int)")
if isinstance(gradient_groups, int):
gradient_groups = list(gradient_groups)
self._gradient_groups = gradient_groups
_acc_config_func_map = {
"fn_flag": set_fn_flag,
"gc_flag": set_gc_flag,
"param_groups": set_param_groups,
"freeze_type": set_freeze_type,
"freeze_p": set_freeze_p,
"total_steps": set_total_steps,
"gradient_groups": set_gradient_groups
}
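AutoAcc maps the acc level onto the feature-switch table ("O0" disables everything, "O1" enables Less BN) and lets the remaining knobs be tuned through kwargs dispatched via _acc_config_func_map. The following is a hedged usage sketch, not taken from the commit's tests: MyNet, the Momentum hyper-parameters and the kwargs values are placeholders, and it assumes the optimizer records its constructor arguments via opt_init_args_register so that OptimizerProcess can rebuild it.

from mindspore import nn
from mindspore.nn.acc import AutoAcc

class MyNet(nn.Cell):
    """Toy backbone used only for this sketch."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(16, 4)

    def construct(self, x):
        return self.dense(x)

network = MyNet()
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)

# The dict keys mirror _acc_config_func_map; unknown keys are silently ignored.
processor = AutoAcc("O1", {"fn_flag": True, "gc_flag": True, "gradient_groups": 10})
network, optimizer = processor.network_auto_process_train(network, optimizer)
eval_network = processor.network_auto_process_eval(MyNet())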

164
mindspore/nn/acc/base.py Normal file
View File

@ -0,0 +1,164 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""base process"""
from mindspore.nn.cell import Cell
from mindspore.nn.optim import LARS
from mindspore import log as logger
from mindspore.common import Parameter, Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
__all__ = ["OptimizerProcess", "ParameterProcess", "GradientAccumulation"]
class OptimizerProcess:
"""
Process optimizer for ACC.
Args:
opt (Cell): Optimizer used.
"""
def __init__(self, opt):
if isinstance(opt, LARS):
self.is_lars = True
self.opt_class = type(opt.opt)
self.opt_init_args = opt.opt.init_args
self.lars_init_args = opt.init_args
else:
self.is_lars = False
self.opt_class = type(opt)
self.opt_init_args = opt.init_args
self.origin_params = opt.init_params["params"]
def add_grad_centralization(self):
"""Add gradient centralization."""
parameters = self.origin_params
if parameters is not None and not isinstance(parameters, list):
parameters = list(parameters)
if not parameters:
raise ValueError("Optimizer got an empty parameter list.")
if not isinstance(parameters[0], (dict, Parameter)):
raise TypeError("Only a list of Parameter or dict can be supported.")
if isinstance(parameters[0], Parameter):
logger.warning("Only group parameters support gradient centralization.")
return parameters
change_dict = parameters[0]
if 'order_params' in change_dict.keys():
logger.warning("Only support normal parameters for gradient centralization.")
return parameters
change_dict['grad_centralization'] = True
self.origin_params[0] = change_dict
return self.origin_params
def generate_new_optimizer(self, params):
"""Generate new optimizer."""
if not self.is_lars:
opt = self.opt_class(params=params, **self.opt_init_args)
else:
opt = LARS(self.opt_class(params=params, **self.opt_init_args), **self.lars_init_args)
return opt
class ParameterProcess:
"""
Process parameter for ACC.
"""
def __init__(self):
self._parameter_indices = 1
def assign_parameter_group(self, parameters, split_point=None):
"""Assign parameter group."""
if not isinstance(parameters, (list, tuple)) or not parameters:
return parameters
parameter_len = len(parameters)
if split_point:
split_parameter_index = split_point
else:
split_parameter_index = [parameter_len // 2]
for i in range(parameter_len):
if i in split_parameter_index:
self._parameter_indices += 1
parameters[i].comm_fusion = self._parameter_indices
return parameters
def generate_group_params(self, parameters):
"""Generate group parameters."""
decayed_params = []
no_decayed_params = []
for param in parameters:
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
decayed_params.append(param)
else:
no_decayed_params.append(param)
group_params = [{'params': decayed_params, 'weight_decay': 0.0001},
{'params': no_decayed_params},
{'order_params': parameters}]
return group_params
_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
def _cumulative_grad(accumulation_step, cumulative_grad, grad):
"""Apply gradient accumulation to cumulative grad."""
return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
@_gradient_clear_op.register("Tensor")
def _clear_grad(cumulative_grad):
zero_grad = P.ZerosLike()(cumulative_grad)
return F.assign(cumulative_grad, zero_grad)
class GradientAccumulation(Cell):
"""
After accumulating the gradients for multiple steps, call the optimizer to apply the accumulated update.
Args:
max_accumulation_step (int): Steps to accumulate gradients.
optimizer (Cell): Optimizer used.
"""
def __init__(self, max_accumulation_step, optimizer):
super(GradientAccumulation, self).__init__()
self._max_accumulation_step = max_accumulation_step
self.optimizer = optimizer
self.weights = optimizer.parameters
self.hyper_map = C.HyperMap()
self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
def construct(self, loss, grads):
loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step),
self._grad_accumulation, grads))
self._accumulation_step += 1
if self._accumulation_step >= self._max_accumulation_step:
loss = F.depend(loss, self.optimizer(self._grad_accumulation))
self._accumulation_step = 0
if self._accumulation_step == 0:
loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation))
return loss
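GradientAccumulation adds grad / max_accumulation_step to a persistent buffer every step and only lets the optimizer run (and the buffer clear) once a full window has been accumulated, so the applied update equals the mean gradient over the window. A plain-Python sketch of that arithmetic follows; scalar gradients stand in for tensors and apply_update stands in for the optimizer call.

def accumulate(grad_stream, max_accumulation_step, apply_update):
    """Model the accumulation loop with plain floats."""
    buffer, step = 0.0, 0
    for grad in grad_stream:
        buffer += grad / max_accumulation_step   # mirrors _cumulative_grad
        step += 1
        if step >= max_accumulation_step:        # window complete: update and clear
            apply_update(buffer)                 # mirrors optimizer(self._grad_accumulation)
            buffer, step = 0.0, 0                # mirrors _clear_grad
    return buffer

updates = []
accumulate([1.0, 2.0, 3.0, 4.0], 4, updates.append)
print(updates)   # [2.5], the mean of the four micro-batch gradients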

View File

@ -18,68 +18,20 @@ import numpy as np
from mindspore.nn.cell import Cell
from mindspore.nn.optim import Optimizer
from mindspore.common import Tensor, Parameter
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.optim import LARS
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import functional as F
__all__ = ['CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
'split_parameters_groups', 'generate_freeze_index_sequence',
'FreezeOpt']
from .base import GradientAccumulation, ParameterProcess
__all__ = ['GradientFreeze', 'FreezeOpt', 'freeze_cell']
CONTINUOUS_STRATEGY = 0
INTERVAL_STRATEGY = 1
def split_parameters_groups(net, freeze_para_groups_number):
"""Split parameter groups for gradients freezing training."""
grouped_params = []
tmp = []
for para in net.trainable_params():
name = para.name
# ensure 'bn' after 'conv' is not split
if 'bn' in name or 'bias' in name:
tmp.append(para)
elif len(tmp) >= 3:
grouped_params.append(tmp)
tmp = [para]
else:
tmp.append(para)
if tmp:
grouped_params.append(tmp)
stride = len(grouped_params) // freeze_para_groups_number
freeze_grouped_params = [sum(grouped_params[i * stride:], []) for i in range(freeze_para_groups_number)]
return freeze_grouped_params
def generate_freeze_index_sequence(parameter_groups_number, freeze_strategy, freeze_p, steps_per_epoch, max_epoch):
"""Generate index sequence for gradient freezing training."""
total_step = steps_per_epoch * max_epoch * 1.01
# local continuous freezing training strategy, as '00001234'
if freeze_strategy == CONTINUOUS_STRATEGY:
zero_cnt = int(freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5)
sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number))
freeze_idxes = []
while len(freeze_idxes) < total_step:
freeze_idxes += sub_idx
return freeze_idxes
# interval freezing training strategy, as '01020304'
if freeze_strategy == INTERVAL_STRATEGY:
index_all = list(range(1, parameter_groups_number))
prob = [x / sum(index_all) for x in index_all]
freeze_idxes = [0]
zero_cnt = 1
freeze_cnt = 0
while len(freeze_idxes) < total_step:
freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt)
if freeze_p_cur < 1 - freeze_p:
freeze_idxes.append(int(np.random.choice(index_all[::-1], p=prob)))
freeze_cnt += 1
else:
freeze_idxes.append(0)
zero_cnt += 1
return freeze_idxes
raise ValueError(f"Unsupported freezing training strategy '{freeze_strategy}'")
class FreezeOpt(Cell):
"""
Optimizer that supports gradients freezing training.
@ -95,12 +47,20 @@ class FreezeOpt(Cell):
def __init__(self, opt, train_parameter_groups=None, train_strategy=None):
super(FreezeOpt, self).__init__()
if not isinstance(opt, Optimizer):
raise TypeError(f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
raise TypeError(
f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
if train_strategy is not None and train_parameter_groups is None:
raise ValueError("When the 'train_strategy' is specified, the value of 'train_parameter_groups' "
"must also be specified")
opt_class = type(opt)
opt_init_args = opt.init_args
if isinstance(opt, LARS):
self.is_lars = True
self.opt_class = type(opt.opt)
self.opt_init_args = opt.opt.init_args
self.lars_init_args = opt.init_args
else:
self.is_lars = False
self.opt_class = type(opt)
self.opt_init_args = opt.init_args
self.opts = []
if train_parameter_groups is None:
@ -108,25 +68,25 @@ class FreezeOpt(Cell):
step = 6
parameters = opt.parameters
para_groups = (parameters[(i * step):] for i in range(groups_num))
self.opts = [opt_class(params=params, **opt_init_args) for params in para_groups]
self.opts = [self._generate_new_optimizer(
params) for params in para_groups]
else:
if not isinstance(train_parameter_groups, (tuple, list)):
raise TypeError("The specified 'train_parameter_groups' should be tuple or list")
raise TypeError(
"The specified 'train_parameter_groups' should be tuple or list")
for params in train_parameter_groups:
if not isinstance(params, (tuple, list)):
raise TypeError("The each element of 'train_parameter_groups' should be tuple or list "
"to store the Parameter")
for para in params:
if not isinstance(para, Parameter):
raise TypeError("The element of each group should be the Parameter")
# generate one-to-one opt corresponding to the parameter group
self.opts.append(opt_class(params=params, **opt_init_args))
self.opts.append(self._generate_new_optimizer(params))
if isinstance(train_strategy, (tuple, list)):
for ele in train_strategy:
if not isinstance(ele, int):
raise ValueError("The element in train_strategy should be int number")
raise ValueError(
"The element in train_strategy should be int number")
self.train_strategy = Tensor(train_strategy, mstype.int32)
elif isinstance(train_strategy, Tensor):
if train_strategy.ndim != 1 or train_strategy.dtype != mstype.int32:
@ -136,4 +96,162 @@ class FreezeOpt(Cell):
elif train_strategy is None:
self.train_strategy = None
else:
raise TypeError("The specified 'train_strategy' should be None, tuple, list or Tensor")
raise TypeError(
"The specified 'train_strategy' should be None, tuple, list or Tensor")
def _generate_new_optimizer(self, params):
"""Generate new optimizer."""
if not self.is_lars:
opt = self.opt_class(params=params, **self.opt_init_args)
else:
opt = LARS(self.opt_class(params=params, **self.opt_init_args),
**self.lars_init_args)
return opt
class _TrainFreezeCell(Cell):
"""
Gradient freezing training network.
Args:
net (Cell): The training network.
sens (Number): The scaling number to be filled as the input of backpropagation. Default: 1.0.
grad (GradOperation): The gradient operation used to compute the gradients of the network parameters.
grad_reducer (Cell): The gradient reducer that applies communication and average operations on the
gradients.
use_grad_accumulation (bool): Whether to use gradient accumulation.
optimizer (Cell): Optimizer for updating the weights.
max_accumulation_step (int): Max gradient accumulation steps. Default: 1.
Supported Platforms:
``Ascend``
"""
def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, optimizer, max_accumulation_step=1):
super(_TrainFreezeCell, self).__init__(auto_prefix=False)
self.net = net
self.grad = grad
self.grad_reducer = grad_reducer
self.opt = optimizer
self.parameters = optimizer.parameters
self.sens = sens
self.use_grad_accumulation = use_grad_accumulation
self.max_accumulation_step = max_accumulation_step
if use_grad_accumulation:
self.grad_accumulation = GradientAccumulation(
self.max_accumulation_step, self.opt)
def construct(self, *inputs):
loss = self.net(*inputs)
sens = F.fill(loss.dtype, loss.shape, self.sens)
grads = self.grad(self.net, self.parameters)(*inputs, sens)
grads = self.grad_reducer(grads)
if self.use_grad_accumulation:
loss = self.grad_accumulation(loss, grads)
else:
loss = F.depend(loss, self.opt(grads))
return loss
class GradientFreeze:
"""
Freeze the gradients of some layers randomly. The number and
probability of frozen layers can be configured by users.
Args:
param_groups (int): Number of parameter groups for gradients freezing training.
freeze_type (int): Strategy of gradients freezing training.
freeze_p (float): Probability of gradients freezing training.
total_steps (int): Total steps of the whole training.
Examples:
>>> gradient_freeze_class = acc.GradientFreeze(10, 1, 0.5, 2000)
>>> network, optimizer = gradient_freeze_class.freeze_generate(network, optimizer)
"""
def __init__(self, param_groups, freeze_type, freeze_p, total_steps):
self._param_groups = param_groups
self._freeze_type = freeze_type
self._freeze_p = freeze_p
self._total_steps = total_steps
self.grad_reducer = F.identity
self._param_processer = ParameterProcess()
def split_parameters_groups(self, net, freeze_para_groups_number):
"""Split parameter groups for gradients freezing training."""
grouped_params = []
tmp = []
for para in net.trainable_params():
name = para.name
# ensure 'bn' after 'conv' is not split
if 'bn' in name or 'bias' in name:
tmp.append(para)
elif len(tmp) >= 3:
grouped_params.append(tmp)
tmp = [para]
else:
tmp.append(para)
if tmp:
grouped_params.append(tmp)
stride = len(grouped_params) // freeze_para_groups_number
freeze_grouped_params = [sum(grouped_params[i * stride:], [])
for i in range(freeze_para_groups_number)]
return freeze_grouped_params
def generate_freeze_index_sequence(self, parameter_groups_number, freeze_strategy, freeze_p, total_steps):
"""Generate index sequence for gradient freezing training."""
total_step = total_steps * 1.01
# local continuous freezing training strategy, as '00001234'
if freeze_strategy == CONTINUOUS_STRATEGY:
zero_cnt = int(
freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5)
sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number))
freeze_idxes = []
while len(freeze_idxes) < total_step:
freeze_idxes += sub_idx
return freeze_idxes
# interval freezing training strategy, as '01020304'
if freeze_strategy == INTERVAL_STRATEGY:
index_all = list(range(1, parameter_groups_number))
prob = [x / sum(index_all) for x in index_all]
freeze_idxes = [0]
zero_cnt = 1
freeze_cnt = 0
while len(freeze_idxes) < total_step:
freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt)
if freeze_p_cur < 1 - freeze_p:
freeze_idxes.append(
int(np.random.choice(index_all[::-1], p=prob)))
freeze_cnt += 1
else:
freeze_idxes.append(0)
zero_cnt += 1
return freeze_idxes
raise ValueError(
f"Unsupported freezing training strategy '{freeze_strategy}'")
def freeze_generate(self, network, optimizer):
"""Generate freeze network and optimizer."""
train_para_groups = self.split_parameters_groups(
network, self._param_groups)
for i in range(self._param_groups):
train_para_groups[i] = self._param_processer.generate_group_params(train_para_groups[i])
train_strategy = self.generate_freeze_index_sequence(
self._param_groups, self._freeze_type, self._freeze_p, self._total_steps)
optimizer = FreezeOpt(optimizer, train_para_groups, train_strategy)
return network, optimizer
def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
max_accumulation_step=1):
"""Provide freeze network cell."""
if reducer_flag:
param_processer = ParameterProcess()
grad_reducers = (DistributedGradReducer(param_processer.assign_parameter_group(opt.parameters),
mean, degree, param_fusion=True) for opt in optimizer.opts)
freeze_nets = tuple(_TrainFreezeCell(network, sens, grad, reducer,
use_grad_accumulation, opt, max_accumulation_step)
for reducer, opt in zip(grad_reducers, optimizer.opts))
else:
freeze_nets = tuple(_TrainFreezeCell(network, sens, grad, F.identity,
use_grad_accumulation, opt, max_accumulation_step)
for opt in optimizer.opts)
return freeze_nets
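FreezeOpt builds one sub-optimizer per parameter-group tail, so schedule index 0 trains every group while larger indices freeze progressively more of the earliest groups; generate_freeze_index_sequence decides which index runs at each step. The CONTINUOUS strategy repeats a fixed block, while the INTERVAL strategy interleaves index 0 with randomly drawn non-zero indices (weighted toward deeper groups) so the overall frozen fraction tracks freeze_p. The sketch below re-implements only the CONTINUOUS branch so the '00001234' comment above can be checked by hand; the numbers are illustrative and the function is not part of the commit.

def continuous_schedule(groups, freeze_p, total_steps):
    """Schedule index 0 updates all groups; index i > 0 skips the first i groups."""
    zero_cnt = int(freeze_p * (groups - 1) / (1 - freeze_p) + 0.5)
    block = [0] * zero_cnt + list(range(1, groups))
    schedule = []
    while len(schedule) < total_steps:
        schedule += block
    return schedule[:total_steps]

print(continuous_schedule(groups=5, freeze_p=0.5, total_steps=16))
# [0, 0, 0, 0, 1, 2, 3, 4, 0, 0, 0, 0, 1, 2, 3, 4]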

View File

@ -93,12 +93,13 @@ class LessBN(Cell):
>>> network = acc.LessBN(network)
"""
def __init__(self, network):
def __init__(self, network, fn_flag=False):
super(LessBN, self).__init__()
self.network = network
self.network.set_acc("less_bn")
self.network.update_cell_prefix()
self._convert_to_less_bn_net(self.network)
if fn_flag:
self._convert_to_less_bn_net(self.network)
self.network.add_flags(defer_inline=True)
def _convert_dense(self, subcell):
@ -110,7 +111,7 @@ class LessBN(Cell):
subcell.out_channels,
subcell.weight,
subcell.bias,
subcell.has_bias)
False)
new_subcell.update_parameters_name(prefix + '.')
return new_subcell
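The new fn_flag switch is what AutoAcc's "O1" path toggles: without it LessBN only marks and re-prefixes the network, with it Dense layers are additionally converted (and the converted layer is now created with has_bias=False). A minimal hedged usage sketch, with a placeholder backbone:

from mindspore import nn
from mindspore.nn.acc import LessBN

backbone = nn.Dense(16, 4)                 # placeholder backbone
backbone = LessBN(backbone, fn_flag=True)  # fn_flag=True also rewrites Dense layers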

View File

@ -19,6 +19,7 @@ from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore.common import Tensor, Parameter, dtype as mstype
from .optimizer import _grad_scale, Optimizer
from .optimizer import opt_init_args_register
_lars_opt = C.MultitypeFuncGraph("lars_opt")
@ -100,6 +101,7 @@ class LARS(Optimizer):
>>> model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)
"""
@opt_init_args_register
def __init__(self, optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False,
lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])

8
mindspore/nn/optim/optimizer.py Executable file → Normal file
View File

@ -37,12 +37,18 @@ from mindspore.nn.learning_rate_schedule import LearningRateSchedule
__all__ = ['Optimizer', 'opt_init_args_register']
def opt_init_args_register(fn):
"""Register optimizer init args."""
def deco(self, *args, **kwargs):
bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
bound_args.apply_defaults()
arguments = bound_args.arguments
arguments.pop('self')
arguments.pop('params')
if 'params' in arguments.keys():
setattr(self, 'init_params', dict({"params": arguments['params']}))
arguments.pop('params')
if 'optimizer' in arguments.keys():
setattr(self, 'init_params', dict({"params": arguments['optimizer'].init_params["params"]}))
arguments.pop('optimizer')
setattr(self, 'init_args', arguments)
fn(self, *args, **kwargs)
return deco
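Recording init_params alongside init_args is what lets the freeze machinery clone an optimizer: FreezeOpt._generate_new_optimizer and OptimizerProcess.generate_new_optimizer re-instantiate type(opt) over a new parameter group using the captured arguments, unwrapping opt.opt first when the optimizer is a LARS wrapper. A minimal sketch of that consumption, assuming the optimizer's __init__ is decorated as above; clone_optimizer is an illustrative helper, not part of the commit.

def clone_optimizer(opt, new_params):
    """Rebuild a decorated optimizer over a different parameter group."""
    # For a LARS wrapper, clone the inner optimizer instead: type(opt.opt), opt.opt.init_args.
    opt_class = type(opt)
    return opt_class(params=new_params, **opt.init_args)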

View File

@ -65,19 +65,6 @@ def _tensors_cast_datatype(datatype, param):
"""
return F.cast(param, datatype)
_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")
@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
def _cumulative_grad(accumulation_step, cumulative_grad, grad):
"""Apply gradient accumulation to cumulative grad."""
return P.AssignAdd()(cumulative_grad, grad / accumulation_step)
_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")
@_gradient_clear_op.register("Tensor")
def _clear_grad(cumulative_grad):
zero_grad = P.ZerosLike()(cumulative_grad)
return F.assign(cumulative_grad, zero_grad)
class WithLossCell(Cell):
r"""
@ -294,32 +281,6 @@ class ForwardValueAndGrad(Cell):
return loss, grads
class _TrainFreezeCell(Cell):
"""Gradient freezing training network."""
def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, max_accumulation_step, optimizer):
super(_TrainFreezeCell, self).__init__(auto_prefix=False)
self.net = net
self.grad = grad
self.grad_reducer = grad_reducer
self.opt = optimizer
self.parameters = optimizer.parameters
self.sens = sens
self.use_grad_accumulation = use_grad_accumulation
if use_grad_accumulation:
self.grad_accumulation = GradientAccumulation(max_accumulation_step, optimizer)
def construct(self, *inputs):
loss = self.net(*inputs)
sens = F.fill(loss.dtype, loss.shape, self.sens)
grads = self.grad(self.net, self.parameters)(*inputs, sens)
grads = self.grad_reducer(grads)
if self.use_grad_accumulation:
loss = self.grad_accumulation(loss, grads)
else:
loss = F.depend(loss, self.opt(grads))
return loss
class TrainOneStepCell(Cell):
r"""
Network training package class.
@ -395,23 +356,20 @@ class TrainOneStepCell(Cell):
self.use_grad_accumulation = False
self.grad_accumulation = None
if self.use_grad_accumulation:
self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
self.grad_accumulation = acc.GradientAccumulation(self.max_accumulation_step, self.optimizer)
if self.reducer_flag:
self.mean = _get_gradients_mean()
self.degree = _get_device_num()
if self.freeze:
self.grad_reducers = (DistributedGradReducer(opt.parameters, self.mean, self.degree)
for opt in self.optimizer.opts)
self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, reducer,
self.use_grad_accumulation, self.max_accumulation_step, opt)
for reducer, opt in zip(self.grad_reducers, self.optimizer))
self.freeze_nets = acc.freeze_cell(self.reducer_flag, self.network, self.optimizer, self.sens,
self.grad, self.use_grad_accumulation, self.mean, self.degree,
self.max_accumulation_step)
else:
self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, self.mean, self.degree)
else:
if self.freeze:
self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, self.grad_reducer,
self.use_grad_accumulation, self.max_accumulation_step, opt)
for opt in self.optimizer.opts)
self.freeze_nets = acc.freeze_cell(self.reducer_flag, self.network, self.optimizer, self.sens,
self.grad, self.use_grad_accumulation,
max_accumulation_step=self.max_accumulation_step)
self.step = Parameter(Tensor(0, dtype=mstype.int32))
def construct(self, *inputs):
@ -726,34 +684,3 @@ class _BroadCastCell(Cell):
params = self.broadcast(params)
new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
return new_params
class GradientAccumulation(Cell):
"""
After accumulating the gradients of multiple steps, call to optimize its update.
Args:
max_accumulation_step (int): Steps to accumulate gradients.
optimizer(Cell):Optimizer used.
"""
def __init__(self, max_accumulation_step, optimizer):
super(GradientAccumulation, self).__init__()
self._max_accumulation_step = max_accumulation_step
self.optimizer = optimizer
self.weights = optimizer.parameters
self.hyper_map = C.HyperMap()
self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")
def construct(self, loss, grads):
loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step),
self._grad_accumulation, grads))
self._accumulation_step += 1
if self._accumulation_step >= self._max_accumulation_step:
loss = F.depend(loss, self.optimizer(self._grad_accumulation))
self._accumulation_step = 0
if self._accumulation_step == 0:
loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation))
return loss
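TrainOneStepCell now delegates the freeze path to acc.freeze_cell, which builds one _TrainFreezeCell per sub-optimizer of a FreezeOpt; the train_strategy then selects which of these cells runs at each step. A hedged end-to-end sketch with a placeholder backbone, explicit Parameter groups and a hand-written strategy (GradientFreeze.freeze_generate can produce both automatically):

from mindspore import nn
from mindspore.nn.acc import FreezeOpt

backbone = nn.Dense(16, 4)
loss_net = nn.WithLossCell(backbone, nn.MSELoss())
opt = nn.Momentum(backbone.trainable_params(), learning_rate=0.01, momentum=0.9)

params = list(backbone.trainable_params())
groups = [params, params[1:]]          # group 1 freezes the earliest parameter
strategy = [0, 0, 1, 0, 0, 1]          # sub-optimizer index to run at each step
freeze_opt = FreezeOpt(opt, train_parameter_groups=groups, train_strategy=strategy)

train_net = nn.TrainOneStepCell(loss_net, freeze_opt)   # routes through acc.freeze_cell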

View File

@ -46,6 +46,18 @@ def _init_allreduce_operators(length, split_indices):
return op_list
def _init_allreduce_operators_by_parameters(parameters):
""" initialize allreduce communication operators by parameters"""
op_list = ()
for parameter in parameters:
comm_fusion = parameter.comm_fusion
op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
op.add_prim_attr('fusion', comm_fusion)
op.add_prim_attr('index', comm_fusion)
op_list = op_list + (op,)
return op_list
@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor")
def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad):
"""
@ -334,7 +346,7 @@ class DistributedGradReducer(Cell):
256.0
"""
def __init__(self, parameters, mean=True, degree=None, fusion_type=1):
def __init__(self, parameters, mean=True, degree=None, fusion_type=1, param_fusion=False):
super(DistributedGradReducer, self).__init__(auto_prefix=False)
self.map_ = C.Map()
if degree is None:
@ -351,6 +363,9 @@ class DistributedGradReducer(Cell):
if is_parallel_optimizer and split_indices:
self.split_fusion = True
self.op_list = _init_allreduce_operators(len(parameters), split_indices)
elif param_fusion:
self.split_fusion = True
self.op_list = _init_allreduce_operators_by_parameters(parameters)
else:
self.split_fusion = False
self.allreduce = AllReduce().add_prim_attr('fusion', fusion_type)
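The new param_fusion branch pairs with ParameterProcess.assign_parameter_group: the comm_fusion index stamped onto each parameter becomes the fusion/index attribute of a dedicated AllReduce, so gradients sharing an index are fused into one communication op. A hedged sketch follows; it is only meaningful inside an initialized distributed job, and the toy network and degree value are placeholders.

from mindspore import nn
from mindspore.nn.acc import ParameterProcess
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

net = nn.Dense(16, 4)
params = ParameterProcess().assign_parameter_group(list(net.trainable_params()))
for p in params:
    print(p.name, p.comm_fusion)       # each parameter now carries its fusion group id
reducer = DistributedGradReducer(params, mean=True, degree=8, param_fusion=True)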

24
mindspore/train/model.py Executable file → Normal file
View File

@ -31,6 +31,7 @@ from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..nn.acc import acc
from ..context import ParallelMode
from ..parallel._cost_model_context import _set_multi_subgraphs
from .dataset_helper import DatasetHelper, connect_network_with_dataset
@ -124,7 +125,7 @@ class Model:
"""
def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
eval_indexes=None, amp_level="O0", **kwargs):
eval_indexes=None, amp_level="O0", acc_level="O0", **kwargs):
self._network = network
self._loss_fn = loss_fn
self._optimizer = optimizer
@ -133,6 +134,8 @@ class Model:
self._keep_bn_fp32 = True
self._check_kwargs(kwargs)
self._amp_level = amp_level
self._acc_level = acc_level
self._eval_network = eval_network
self._process_amp_args(kwargs)
self._parallel_mode = _get_parallel_mode()
self._device_number = _get_device_num()
@ -141,8 +144,9 @@ class Model:
self._check_amp_level_arg(optimizer, amp_level)
self._check_for_graph_cell(kwargs)
self._build_acc_network(kwargs)
self._train_network = self._build_train_network()
self._build_eval_network(metrics, eval_network, eval_indexes)
self._build_eval_network(metrics, self._eval_network, eval_indexes)
self._build_predict_network()
def _check_for_graph_cell(self, kwargs):
@ -173,7 +177,7 @@ class Model:
def _check_kwargs(self, kwargs):
for arg in kwargs:
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32', 'total_steps']:
raise ValueError(f"Unsupported arg '{arg}'")
def _check_reuse_dataset(self, dataset):
@ -182,11 +186,25 @@ class Model:
if hasattr(dataset, '__model_hash__') and dataset.__model_hash__ != hash(self):
raise RuntimeError('The Dataset cannot be bound to different models, please create a new dataset.')
def _build_acc_network(self, kwargs):
"""Build the acc network."""
processor = acc.AutoAcc(self._acc_level, kwargs)
if self._optimizer is None:
logger.warning("In acc mode, the optimizer must be defined.")
return
if self._eval_network is None:
logger.warning("In acc mode, the eval_network must be defined.")
return
self._network, self._optimizer = processor.network_auto_process_train(self._network, self._optimizer)
self._eval_network = processor.network_auto_process_eval(self._eval_network)
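From the user's side the whole feature is driven by the new acc_level argument: "O1" routes the network, optimizer and eval network through AutoAcc, and extra kwargs such as total_steps are forwarded to it (which is why _check_kwargs now whitelists 'total_steps'). A hedged sketch with placeholder backbone, loss and hyper-parameters:

from mindspore import Model, nn
from mindspore.nn.metrics import Loss

backbone = nn.Dense(16, 4)
loss = nn.MSELoss()
opt = nn.Momentum(backbone.trainable_params(), learning_rate=0.01, momentum=0.9)

model = Model(backbone, loss_fn=loss, optimizer=opt,
              eval_network=nn.WithEvalCell(backbone, loss),
              metrics={"loss": Loss()},
              acc_level="O1", total_steps=1000)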
def _build_train_network(self):
"""Build train network"""
network = self._network
if self._loss_scale_manager is not None and self._optimizer is None:
raise ValueError("Optimizer can not be None when set loss_scale_manager.")
if self._optimizer:
if self._loss_scale_manager_set:
network = amp.build_train_network(network,