forked from mindspore-Ecosystem/mindspore
!17597 [ACC] Gradient freeze optimization.
Merge pull request !17597 from linqingke/grad_freeze
commit a00cbb6e22
@@ -31,11 +31,15 @@ constexpr auto kFirstBranchPattern1 = 12;
constexpr auto kSecondBranchPattern1 = 3;
constexpr auto kFirstBranchStartIndexPattern1 = 4;
constexpr auto kFirstBranchEndIndexPattern1 = 11;
constexpr auto kSecondBranchStartIndexPattern1 = 12;
constexpr auto kSecondBranchEndIndexPattern1 = 14;
const std::vector<kStructureTuple> ResidualStructureBasePattern{
  {kFirstBranchPattern1,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
   {kFirstBranchStartIndexPattern1, kFirstBranchEndIndexPattern1}},
  {kSecondBranchPattern1, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
  {kSecondBranchPattern1,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
   {kSecondBranchStartIndexPattern1, kSecondBranchEndIndexPattern1}}};
// Pattern 2
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ -> -> ... ... ... -> -> ↗
@@ -43,11 +47,13 @@ constexpr auto kFirstBranchPattern2 = 12;
constexpr auto kSecondBranchPattern2 = 1;
constexpr auto kFirstBranchStartIndexPattern2 = 4;
constexpr auto kFirstBranchEndIndexPattern2 = 11;
constexpr auto kSecondBranchStartIndexPattern2 = 12;
constexpr auto kSecondBranchEndIndexPattern2 = 13;
const std::vector<kStructureTuple> ResidualStructureShortCutPattern{
  {kFirstBranchPattern2,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
   {kFirstBranchStartIndexPattern2, kFirstBranchEndIndexPattern2}},
  {kSecondBranchPattern2, {prim::kPrimRelu}, {SIZE_MAX, SIZE_MAX}}};
  {kSecondBranchPattern2, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern2, kSecondBranchEndIndexPattern2}}};
// Pattern 3
// Add -> BatchNorm -> Conv2D -> Relu ... BatchNorm -> Conv2D -> End
// ↘ BatchNorm -> Conv2D -> -> ... ... ... -> -> ↗
@@ -55,15 +61,65 @@ constexpr auto kFirstBranchPattern3 = 11;
constexpr auto kSecondBranchPattern3 = 3;
constexpr auto kFirstBranchStartIndexPattern3 = 4;
constexpr auto kFirstBranchEndIndexPattern3 = 10;
constexpr auto kSecondBranchStartIndexPattern3 = 11;
constexpr auto kSecondBranchEndIndexPattern3 = 13;
const std::vector<kStructureTuple> ResidualStructureFirstStepPattern{
  {kFirstBranchPattern3,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
    prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem, prim::kPrimBatchNorm,
    prim::kPrimConv2D},
   {kFirstBranchStartIndexPattern3, kFirstBranchEndIndexPattern3}},
  {kSecondBranchPattern3, {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D}, {SIZE_MAX, SIZE_MAX}}};
  {kSecondBranchPattern3,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
   {kSecondBranchStartIndexPattern3, kSecondBranchEndIndexPattern3}}};
// Pattern 4
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ BatchNorm -> Conv2D -> -> -> -> ↗
constexpr auto kFirstBranchPattern4 = 8;
constexpr auto kSecondBranchPattern4 = 3;
constexpr auto kFirstBranchStartIndexPattern4 = 4;
constexpr auto kFirstBranchEndIndexPattern4 = 6;
constexpr auto kSecondBranchStartIndexPattern4 = 8;
constexpr auto kSecondBranchEndIndexPattern4 = 11;
const std::vector<kStructureTuple> BasicStructureBasePattern{
  {kFirstBranchPattern4,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
   {kFirstBranchStartIndexPattern4, kFirstBranchEndIndexPattern4}},
  {kSecondBranchPattern4,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D},
   {kSecondBranchStartIndexPattern4, kSecondBranchEndIndexPattern4}}};
// Pattern 5
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ -> -> -> -> Relu -> -> -> -> ↗
constexpr auto kFirstBranchPattern5 = 8;
constexpr auto kSecondBranchPattern5 = 1;
constexpr auto kFirstBranchStartIndexPattern5 = 4;
constexpr auto kFirstBranchEndIndexPattern5 = 6;
constexpr auto kSecondBranchStartIndexPattern5 = 8;
constexpr auto kSecondBranchEndIndexPattern5 = 11;
const std::vector<kStructureTuple> BasicStructureShortCutPattern{
  {kFirstBranchPattern5,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu},
   {kFirstBranchStartIndexPattern5, kFirstBranchEndIndexPattern5}},
  {kSecondBranchPattern5, {prim::kPrimRelu}, {kSecondBranchStartIndexPattern5, kSecondBranchEndIndexPattern5}}};
// Pattern 6
// Add -> BatchNorm -> Conv2D -> Relu ... -> End
// ↘ -> -> -> -> MaxPool -> -> -> ↗
constexpr auto kFirstBranchPattern6 = 7;
constexpr auto kSecondBranchPattern6 = 1;
constexpr auto kFirstBranchStartIndexPattern6 = 4;
constexpr auto kFirstBranchEndIndexPattern6 = 6;
constexpr auto kSecondBranchStartIndexPattern6 = 7;
constexpr auto kSecondBranchEndIndexPattern6 = 10;
const std::vector<kStructureTuple> BasicStructureFirstStepPattern{
  {kFirstBranchPattern6,
   {prim::kPrimTupleGetItem, prim::kPrimBatchNorm, prim::kPrimConv2D, prim::kPrimRelu, prim::kPrimTupleGetItem,
    prim::kPrimBatchNorm, prim::kPrimConv2D},
   {kFirstBranchStartIndexPattern6, kFirstBranchEndIndexPattern6}},
  {kSecondBranchPattern6, {prim::kPrimMaxPool}, {kSecondBranchStartIndexPattern6, kSecondBranchEndIndexPattern6}}};
static const std::vector<std::vector<kStructureTuple>> kNeedMatchPattern = {
  ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern};
  ResidualStructureBasePattern, ResidualStructureShortCutPattern, ResidualStructureFirstStepPattern,
  BasicStructureBasePattern, BasicStructureShortCutPattern, BasicStructureFirstStepPattern};
const std::set<PrimitivePtr> kNeedRemoveNodeSet{
  prim::kPrimLoad, prim::kPrimRefToEmbed, prim::kPrimApplyMomentum, prim::kPrimMomentum,
  prim::kPrimApplyFtrl, prim::kPrimSGD, prim::kPrimApplyRMSProp, prim::kPrimAdam};
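For orientation, the six patterns above enumerate the TupleGetItem/BatchNorm/Conv2D/Relu/MaxPool node chains that make up ResNet-style residual and basic blocks, each anchored at the Add node where the shortcut branch rejoins the main branch. A minimal sketch of the block shape being described, assuming standard mindspore.nn layers (channel sizes and the class name are illustrative only, not part of this commit):

import mindspore.nn as nn
from mindspore.ops import operations as P

class ResidualBlockSketch(nn.Cell):
    """Illustrative Conv2D -> BatchNorm -> ReLU chain plus a shortcut Add, the shape the patterns describe."""
    def __init__(self, channels=64):
        super(ResidualBlockSketch, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()
        self.add = P.Add()

    def construct(self, x):
        shortcut = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.add(out, shortcut)  # the Add node each pattern starts from
        return self.relu(out)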
@@ -15,10 +15,16 @@
"""
Accelerating.

Provide auto accelerating for network, such as Less BN.
Provide auto accelerating for network, such as Less BN, Gradient Freeze.
"""
from .acc import *
from .base import *
from .less_batch_normalization import *
from .grad_freeze import *

__all__ = ['LessBN', 'FreezeOpt', 'CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
           'split_parameters_groups', 'generate_freeze_index_sequence']

__all__ = ['AutoAcc',
           'OptimizerProcess', 'ParameterProcess',
           'LessBN',
           'GradientFreeze', 'FreezeOpt', 'freeze_cell',
           'GradientAccumulation']
@@ -0,0 +1,125 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""acc"""
from easydict import EasyDict as edict

from .less_batch_normalization import LessBN
from .grad_freeze import GradientFreeze
from .base import OptimizerProcess, ParameterProcess


__all__ = ["AutoAcc"]


_acc_config_level = {
    "O0": {
        "less_bn": False,
        "grad_freeze": False},
    "O1": {
        "less_bn": True,
        "grad_freeze": False}}


class AutoAcc:
    """
    Provide auto accelerating for network.

    Args:
        level (Str): acc config level.
    """
    def __init__(self, level, kwargs):
        if level not in _acc_config_level.keys():
            level = 'O0'
        acc_config = _acc_config_level[level]
        self._acc_config = edict(acc_config)
        self._fn_flag = True
        self._gc_flag = True
        self._param_groups = 10
        self._freeze_type = 1
        self._freeze_p = 0.5
        self._total_steps = -1
        self._gradient_groups = None
        self._get_configuration(kwargs)
        self._param_processer = ParameterProcess()

    def _get_configuration(self, kwargs):
        """Get configuration."""
        for key, val in kwargs.items():
            if key not in self._acc_config_func_map.keys():
                continue
            self._acc_config_func_map[key](self, val)

    def network_auto_process_train(self, network, optimizer):
        """Network train."""
        if self._acc_config.less_bn:
            network = LessBN(network, fn_flag=self._fn_flag)
            optimizer_process = OptimizerProcess(optimizer)
            group_params = self._param_processer.assign_parameter_group(network.trainable_params(),
                                                                        self._gradient_groups)
            optimizer_process.origin_params = self._param_processer.generate_group_params(group_params)
            if self._gc_flag:
                parameters = optimizer_process.add_grad_centralization()
            else:
                parameters = optimizer_process.origin_params
            optimizer = optimizer_process.generate_new_optimizer(parameters)

        if self._acc_config.grad_freeze:
            freeze_processer = GradientFreeze(self._param_groups, self._freeze_type,
                                              self._freeze_p, self._total_steps)
            network, optimizer = freeze_processer.freeze_generate(network, optimizer)

        return network, optimizer

    def network_auto_process_eval(self, network):
        """Network eval."""
        if self._acc_config.less_bn:
            network = LessBN(network)

        return network

    def set_fn_flag(self, fn_flag):
        self._fn_flag = fn_flag

    def set_gc_flag(self, gc_flag):
        self._gc_flag = gc_flag

    def set_param_groups(self, param_groups):
        self._param_groups = param_groups

    def set_freeze_type(self, freeze_type):
        self._freeze_type = freeze_type

    def set_freeze_p(self, freeze_p):
        self._freeze_p = freeze_p

    def set_total_steps(self, total_steps):
        self._total_steps = total_steps

    def set_gradient_groups(self, gradient_groups):
        if not isinstance(gradient_groups, (list, int)):
            raise ValueError(f"gradient_groups `{gradient_groups}` is not in (list, int)")
        if isinstance(gradient_groups, int):
            gradient_groups = [gradient_groups]
        self._gradient_groups = gradient_groups

    _acc_config_func_map = {
        "fn_flag": set_fn_flag,
        "gc_flag": set_gc_flag,
        "param_groups": set_param_groups,
        "freeze_type": set_freeze_type,
        "freeze_p": set_freeze_p,
        "total_steps": set_total_steps,
        "gradient_groups": set_gradient_groups
    }
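A rough usage sketch of the class above; the level string and keyword overrides map onto _acc_config_level and _acc_config_func_map, the network objects are hypothetical placeholders, and the optimizer is assumed to carry init_args/init_params via the opt_init_args_register decorator added elsewhere in this diff:

from mindspore import nn
from mindspore.nn.acc import acc

net = MyTrainNet()      # hypothetical training network
eval_net = MyEvalNet()  # hypothetical evaluation network
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

processor = acc.AutoAcc("O1", {"gc_flag": False, "total_steps": 1000})
net, opt = processor.network_auto_process_train(net, opt)  # LessBN wrapper, parameter grouping, rebuilt optimizer
eval_net = processor.network_auto_process_eval(eval_net)   # the eval side only gets the LessBN wrapper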
@@ -0,0 +1,164 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""base process"""
from mindspore.nn.cell import Cell
from mindspore.nn.optim import LARS
from mindspore import log as logger
from mindspore.common import Parameter, Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P


__all__ = ["OptimizerProcess", "ParameterProcess", "GradientAccumulation"]


class OptimizerProcess:
    """
    Process optimizer for ACC.

    Args:
        opt (Cell): Optimizer used.
    """
    def __init__(self, opt):
        if isinstance(opt, LARS):
            self.is_lars = True
            self.opt_class = type(opt.opt)
            self.opt_init_args = opt.opt.init_args
            self.lars_init_args = opt.init_args
        else:
            self.is_lars = False
            self.opt_class = type(opt)
            self.opt_init_args = opt.init_args
        self.origin_params = opt.init_params["params"]

    def add_grad_centralization(self):
        """Add gradient centralization."""
        parameters = self.origin_params
        if parameters is not None and not isinstance(parameters, list):
            parameters = list(parameters)

        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")

        if not isinstance(parameters[0], (dict, Parameter)):
            raise TypeError("Only a list of Parameter or dict can be supported.")

        if isinstance(parameters[0], Parameter):
            logger.warning("Only group parameters support gradient centralization.")
            return parameters

        change_dict = parameters[0]
        if 'order_params' in change_dict.keys():
            logger.warning("Only support normal parameters for gradient centralization.")
            return parameters

        change_dict['grad_centralization'] = True
        self.origin_params[0] = change_dict

        return self.origin_params

    def generate_new_optimizer(self, params):
        """Generate new optimizer."""
        if not self.is_lars:
            opt = self.opt_class(params=params, **self.opt_init_args)
        else:
            opt = LARS(self.opt_class(params=params, **self.opt_init_args), **self.lars_init_args)

        return opt
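A rough sketch of the flow above, assuming the optimizer's __init__ is wrapped with the opt_init_args_register decorator added later in this diff (so init_args and init_params exist) and net is a placeholder network:

from mindspore import nn
from mindspore.nn.acc import OptimizerProcess

opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

proc = OptimizerProcess(opt)                   # captures the optimizer class and its init args
params = proc.add_grad_centralization()        # only takes effect when origin_params are grouped dicts
new_opt = proc.generate_new_optimizer(params)  # rebuilds the same optimizer class over the new groups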


class ParameterProcess:
    """
    Process parameter for ACC.
    """
    def __init__(self):
        self._parameter_indices = 1

    def assign_parameter_group(self, parameters, split_point=None):
        """Assign parameter group."""
        if not isinstance(parameters, (list, tuple)) or not parameters:
            return parameters

        parameter_len = len(parameters)
        if split_point:
            split_parameter_index = split_point
        else:
            split_parameter_index = [parameter_len // 2]
        for i in range(parameter_len):
            if i in split_parameter_index:
                self._parameter_indices += 1
            parameters[i].comm_fusion = self._parameter_indices
        return parameters

    def generate_group_params(self, parameters):
        """Generate group parameters."""
        decayed_params = []
        no_decayed_params = []
        for param in parameters:
            if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
                decayed_params.append(param)
            else:
                no_decayed_params.append(param)
        group_params = [{'params': decayed_params, 'weight_decay': 0.0001},
                        {'params': no_decayed_params},
                        {'order_params': parameters}]

        return group_params


_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")


@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
def _cumulative_grad(accumulation_step, cumulative_grad, grad):
    """Apply gradient accumulation to cumulative grad."""
    return P.AssignAdd()(cumulative_grad, grad / accumulation_step)


_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")


@_gradient_clear_op.register("Tensor")
def _clear_grad(cumulative_grad):
    zero_grad = P.ZerosLike()(cumulative_grad)
    return F.assign(cumulative_grad, zero_grad)


class GradientAccumulation(Cell):
    """
    After accumulating the gradients of multiple steps, call to optimize its update.

    Args:
        max_accumulation_step (int): Steps to accumulate gradients.
        optimizer (Cell): Optimizer used.
    """
    def __init__(self, max_accumulation_step, optimizer):
        super(GradientAccumulation, self).__init__()
        self._max_accumulation_step = max_accumulation_step
        self.optimizer = optimizer
        self.weights = optimizer.parameters
        self.hyper_map = C.HyperMap()
        self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
        self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")

    def construct(self, loss, grads):
        loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step),
                                             self._grad_accumulation, grads))
        self._accumulation_step += 1

        if self._accumulation_step >= self._max_accumulation_step:
            loss = F.depend(loss, self.optimizer(self._grad_accumulation))
            self._accumulation_step = 0

        if self._accumulation_step == 0:
            loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation))

        return loss
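In effect, construct adds grad / max_accumulation_step into each accumulator on every step, lets the optimizer consume the accumulated (averaged) gradients once max_accumulation_step steps have been seen, and then clears the accumulators. A plain-Python illustration of that bookkeeping, with made-up scalar gradients:

# Scalar stand-in for one parameter's accumulator; the real cell uses cloned Parameter tensors.
max_accumulation_step = 4
acc_grad, step = 0.0, 0
for grad in [1.0, 2.0, 3.0, 2.0]:             # gradients from four micro-steps
    acc_grad += grad / max_accumulation_step  # mirrors _cumulative_grad
    step += 1
    if step >= max_accumulation_step:         # mirrors the optimizer call in construct
        print("apply update with averaged grad:", acc_grad)  # prints 2.0, the mean of the four gradients
        acc_grad, step = 0.0, 0               # mirrors _clear_grad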
@@ -18,68 +18,20 @@ import numpy as np

from mindspore.nn.cell import Cell
from mindspore.nn.optim import Optimizer
from mindspore.common import Tensor, Parameter
from mindspore.common import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.optim import LARS
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import functional as F

__all__ = ['CONTINUOUS_STRATEGY', 'INTERVAL_STRATEGY',
           'split_parameters_groups', 'generate_freeze_index_sequence',
           'FreezeOpt']
from .base import GradientAccumulation, ParameterProcess

__all__ = ['GradientFreeze', 'FreezeOpt', 'freeze_cell']

CONTINUOUS_STRATEGY = 0
INTERVAL_STRATEGY = 1


def split_parameters_groups(net, freeze_para_groups_number):
    """Split parameter groups for gradients freezing training."""
    grouped_params = []
    tmp = []
    for para in net.trainable_params():
        name = para.name
        # ensure 'bn' after 'conv' is not split
        if 'bn' in name or 'bias' in name:
            tmp.append(para)
        elif len(tmp) >= 3:
            grouped_params.append(tmp)
            tmp = [para]
        else:
            tmp.append(para)
    if tmp:
        grouped_params.append(tmp)
    stride = len(grouped_params) // freeze_para_groups_number
    freeze_grouped_params = [sum(grouped_params[i * stride:], []) for i in range(freeze_para_groups_number)]
    return freeze_grouped_params


def generate_freeze_index_sequence(parameter_groups_number, freeze_strategy, freeze_p, steps_per_epoch, max_epoch):
    """Generate index sequence for gradient freezing training."""
    total_step = steps_per_epoch * max_epoch * 1.01
    # local continuous freezing training strategy, as '00001234'
    if freeze_strategy == CONTINUOUS_STRATEGY:
        zero_cnt = int(freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5)
        sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number))
        freeze_idxes = []
        while len(freeze_idxes) < total_step:
            freeze_idxes += sub_idx
        return freeze_idxes
    # interval freezing training strategy, as '01020304'
    if freeze_strategy == INTERVAL_STRATEGY:
        index_all = list(range(1, parameter_groups_number))
        prob = [x / sum(index_all) for x in index_all]
        freeze_idxes = [0]
        zero_cnt = 1
        freeze_cnt = 0
        while len(freeze_idxes) < total_step:
            freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt)
            if freeze_p_cur < 1 - freeze_p:
                freeze_idxes.append(int(np.random.choice(index_all[::-1], p=prob)))
                freeze_cnt += 1
            else:
                freeze_idxes.append(0)
                zero_cnt += 1
        return freeze_idxes
    raise ValueError(f"Unsupported freezing training strategy '{freeze_strategy}'")


class FreezeOpt(Cell):
    """
    Optimizer that supports gradients freezing training.
@@ -95,12 +47,20 @@ class FreezeOpt(Cell):
    def __init__(self, opt, train_parameter_groups=None, train_strategy=None):
        super(FreezeOpt, self).__init__()
        if not isinstance(opt, Optimizer):
            raise TypeError(f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
            raise TypeError(
                f"The first arg 'opt' must be an Optimizer instance, but got {type(opt)}")
        if train_strategy is not None and train_parameter_groups is None:
            raise ValueError("When the 'train_strategy' is specified, the value of 'train_parameter_groups' "
                             "must also be specified")
        opt_class = type(opt)
        opt_init_args = opt.init_args
        if isinstance(opt, LARS):
            self.is_lars = True
            self.opt_class = type(opt.opt)
            self.opt_init_args = opt.opt.init_args
            self.lars_init_args = opt.init_args
        else:
            self.is_lars = False
            self.opt_class = type(opt)
            self.opt_init_args = opt.init_args
        self.opts = []

        if train_parameter_groups is None:
@@ -108,25 +68,25 @@ class FreezeOpt(Cell):
            step = 6
            parameters = opt.parameters
            para_groups = (parameters[(i * step):] for i in range(groups_num))
            self.opts = [opt_class(params=params, **opt_init_args) for params in para_groups]
            self.opts = [self._generate_new_optimizer(
                params) for params in para_groups]
        else:
            if not isinstance(train_parameter_groups, (tuple, list)):
                raise TypeError("The specified 'train_parameter_groups' should be tuple or list")
                raise TypeError(
                    "The specified 'train_parameter_groups' should be tuple or list")
            for params in train_parameter_groups:
                if not isinstance(params, (tuple, list)):
                    raise TypeError("The each element of 'train_parameter_groups' should be tuple or list "
                                    "to store the Parameter")
                for para in params:
                    if not isinstance(para, Parameter):
                        raise TypeError("The element of each group should be the Parameter")

                # generate one-to-one opt corresponding to the parameter group
                self.opts.append(opt_class(params=params, **opt_init_args))
                self.opts.append(self._generate_new_optimizer(params))

        if isinstance(train_strategy, (tuple, list)):
            for ele in train_strategy:
                if not isinstance(ele, int):
                    raise ValueError("The element in train_strategy should be int number")
                    raise ValueError(
                        "The element in train_strategy should be int number")
            self.train_strategy = Tensor(train_strategy, mstype.int32)
        elif isinstance(train_strategy, Tensor):
            if train_strategy.ndim != 1 or train_strategy.dtype != mstype.int32:
@@ -136,4 +96,162 @@ class FreezeOpt(Cell):
        elif train_strategy is None:
            self.train_strategy = None
        else:
            raise TypeError("The specified 'train_strategy' should be None, tuple, list or Tensor")
            raise TypeError(
                "The specified 'train_strategy' should be None, tuple, list or Tensor")

    def _generate_new_optimizer(self, params):
        """Generate new optimizer."""
        if not self.is_lars:
            opt = self.opt_class(params=params, **self.opt_init_args)
        else:
            opt = LARS(self.opt_class(params=params, **self.opt_init_args),
                       **self.lars_init_args)
        return opt


class _TrainFreezeCell(Cell):
    """
    Gradient freezing training network.

    Args:
        net (Cell): The training network.
        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
        grad (tuple(tensor)): The gradients of network parameters and inputs.
        grad_reducer (Cell): Constructs a gradient reducer Cell, which applies communication and average operations on
            single-process gradient values.
        use_grad_accumulation (Bool): Whether to use grad accumulation.
        optimizer (Union[Cell]): Optimizer for updating the weights.
        max_accumulation_step (Number): Max grad accumulation steps. Default: 1.

    Supported Platforms:
        ``Ascend``
    """
    def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, optimizer, max_accumulation_step=1):
        super(_TrainFreezeCell, self).__init__(auto_prefix=False)
        self.net = net
        self.grad = grad
        self.grad_reducer = grad_reducer
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.sens = sens
        self.use_grad_accumulation = use_grad_accumulation
        self.max_accumulation_step = max_accumulation_step
        if use_grad_accumulation:
            self.grad_accumulation = GradientAccumulation(
                self.max_accumulation_step, self.opt)

    def construct(self, *inputs):
        loss = self.net(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.net, self.parameters)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        else:
            loss = F.depend(loss, self.opt(grads))
        return loss


class GradientFreeze:
    """
    Freezing the gradients of some layers randomly. The number and
    probability of frozen layers can be configured by users.

    Args:
        param_groups (Union[Tuple, List]): Groups of parameters for gradients freezing training.
        freeze_type (Int): Strategy of gradients freezing training.
        freeze_p (Float): Probability of gradients freezing training.
        total_steps (Number): Steps of the whole training.

    Examples:
        >>> gradient_freeze_class = acc.GradientFreeze(10, 1, 0.5, 2000)
        >>> network, optimizer = gradient_freeze_class.freeze_generate(network, optimizer)
    """
    def __init__(self, param_groups, freeze_type, freeze_p, total_steps):
        self._param_groups = param_groups
        self._freeze_type = freeze_type
        self._freeze_p = freeze_p
        self._total_steps = total_steps
        self.grad_reducer = F.identity
        self._param_processer = ParameterProcess()

    def split_parameters_groups(self, net, freeze_para_groups_number):
        """Split parameter groups for gradients freezing training."""
        grouped_params = []
        tmp = []
        for para in net.trainable_params():
            name = para.name
            # ensure 'bn' after 'conv' is not split
            if 'bn' in name or 'bias' in name:
                tmp.append(para)
            elif len(tmp) >= 3:
                grouped_params.append(tmp)
                tmp = [para]
            else:
                tmp.append(para)
        if tmp:
            grouped_params.append(tmp)
        stride = len(grouped_params) // freeze_para_groups_number
        freeze_grouped_params = [sum(grouped_params[i * stride:], [])
                                 for i in range(freeze_para_groups_number)]
        return freeze_grouped_params

    def generate_freeze_index_sequence(self, parameter_groups_number, freeze_strategy, freeze_p, total_steps):
        """Generate index sequence for gradient freezing training."""
        total_step = total_steps * 1.01
        # local continuous freezing training strategy, as '00001234'
        if freeze_strategy == CONTINUOUS_STRATEGY:
            zero_cnt = int(
                freeze_p * (parameter_groups_number - 1) / (1 - freeze_p) + 0.5)
            sub_idx = [0] * zero_cnt + list(range(1, parameter_groups_number))
            freeze_idxes = []
            while len(freeze_idxes) < total_step:
                freeze_idxes += sub_idx
            return freeze_idxes
        # interval freezing training strategy, as '01020304'
        if freeze_strategy == INTERVAL_STRATEGY:
            index_all = list(range(1, parameter_groups_number))
            prob = [x / sum(index_all) for x in index_all]
            freeze_idxes = [0]
            zero_cnt = 1
            freeze_cnt = 0
            while len(freeze_idxes) < total_step:
                freeze_p_cur = 1.0 * freeze_cnt / (zero_cnt + freeze_cnt)
                if freeze_p_cur < 1 - freeze_p:
                    freeze_idxes.append(
                        int(np.random.choice(index_all[::-1], p=prob)))
                    freeze_cnt += 1
                else:
                    freeze_idxes.append(0)
                    zero_cnt += 1
            return freeze_idxes
        raise ValueError(
            f"Unsupported freezing training strategy '{freeze_strategy}'")

    def freeze_generate(self, network, optimizer):
        """Generate freeze network and optimizer."""
        train_para_groups = self.split_parameters_groups(
            network, self._param_groups)
        for i in range(self._param_groups):
            train_para_groups[i] = self._param_processer.generate_group_params(train_para_groups[i])
        train_strategy = self.generate_freeze_index_sequence(
            self._param_groups, self._freeze_type, self._freeze_p, self._total_steps)
        optimizer = FreezeOpt(optimizer, train_para_groups, train_strategy)

        return network, optimizer


def freeze_cell(reducer_flag, network, optimizer, sens, grad, use_grad_accumulation, mean=None, degree=None,
                max_accumulation_step=1):
    """Provide freeze network cell."""
    if reducer_flag:
        param_processer = ParameterProcess()
        grad_reducers = (DistributedGradReducer(param_processer.assign_parameter_group(opt.parameters),
                                                mean, degree, param_fusion=True) for opt in optimizer.opts)
        freeze_nets = tuple(_TrainFreezeCell(network, sens, grad, reducer,
                                             use_grad_accumulation, opt, max_accumulation_step)
                            for reducer, opt in zip(grad_reducers, optimizer.opts))
    else:
        freeze_nets = tuple(_TrainFreezeCell(network, sens, grad, F.identity,
                                             use_grad_accumulation, opt, max_accumulation_step)
                            for opt in optimizer.opts)
    return freeze_nets
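A hedged end-to-end sketch of how the pieces in this file are meant to be combined; net, the loss and the optimizer are placeholders, the optimizer is assumed to carry init_args via opt_init_args_register, and TrainOneStepCell is assumed to take the freeze branch shown later in this diff when it receives a FreezeOpt:

from mindspore import nn
from mindspore.nn.acc import GradientFreeze

loss_net = nn.WithLossCell(net, nn.SoftmaxCrossEntropyWithLogits())
opt = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)

freezer = GradientFreeze(param_groups=10, freeze_type=1, freeze_p=0.5, total_steps=2000)
loss_net, freeze_opt = freezer.freeze_generate(loss_net, opt)  # freeze_opt is a FreezeOpt, one sub-optimizer per group
train_net = nn.TrainOneStepCell(loss_net, freeze_opt)          # the per-step train_strategy index picks a freeze net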
@@ -93,12 +93,13 @@ class LessBN(Cell):
        >>> network = acc.LessBN(network)
    """

    def __init__(self, network):
    def __init__(self, network, fn_flag=False):
        super(LessBN, self).__init__()
        self.network = network
        self.network.set_acc("less_bn")
        self.network.update_cell_prefix()
        self._convert_to_less_bn_net(self.network)
        if fn_flag:
            self._convert_to_less_bn_net(self.network)
        self.network.add_flags(defer_inline=True)

    def _convert_dense(self, subcell):
@@ -110,7 +111,7 @@ class LessBN(Cell):
                                   subcell.out_channels,
                                   subcell.weight,
                                   subcell.bias,
                                   subcell.has_bias)
                                   False)
        new_subcell.update_parameters_name(prefix + '.')

        return new_subcell
@@ -19,6 +19,7 @@ from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore.common import Tensor, Parameter, dtype as mstype
from .optimizer import _grad_scale, Optimizer
from .optimizer import opt_init_args_register

_lars_opt = C.MultitypeFuncGraph("lars_opt")

@@ -100,6 +101,7 @@ class LARS(Optimizer):
        >>> model = Model(net, loss_fn=loss, optimizer=opt_lars, metrics=None)
    """

    @opt_init_args_register
    def __init__(self, optimizer, epsilon=1e-05, coefficient=0.001, use_clip=False,
                 lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name):
        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="fake_param")])
@@ -37,12 +37,18 @@ from mindspore.nn.learning_rate_schedule import LearningRateSchedule
__all__ = ['Optimizer', 'opt_init_args_register']

def opt_init_args_register(fn):
    """Register optimizer init args."""
    def deco(self, *args, **kwargs):
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        arguments.pop('self')
        arguments.pop('params')
        if 'params' in arguments.keys():
            setattr(self, 'init_params', dict({"params": arguments['params']}))
            arguments.pop('params')
        if 'optimizer' in arguments.keys():
            setattr(self, 'init_params', dict({"params": arguments['optimizer'].init_params["params"]}))
            arguments.pop('optimizer')
        setattr(self, 'init_args', arguments)
        fn(self, *args, **kwargs)
    return deco
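For reference, the decorator stashes the constructor arguments on the optimizer instance, and OptimizerProcess and FreezeOpt read them back to rebuild optimizers over new parameter groups. A hypothetical sketch, assuming the optimizer's __init__ has been wrapped with @opt_init_args_register and net is a placeholder:

from mindspore import nn

opt = nn.SGD(net.trainable_params(), learning_rate=0.01, momentum=0.9)

print(opt.init_params["params"])  # the original 'params' argument, kept separately
print(opt.init_args)              # the remaining init kwargs, with 'params'/'optimizer' popped
# OptimizerProcess.generate_new_optimizer and FreezeOpt._generate_new_optimizer replay these:
new_opt = type(opt)(params=opt.init_params["params"], **opt.init_args)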
@@ -65,19 +65,6 @@ def _tensors_cast_datatype(datatype, param):
    """
    return F.cast(param, datatype)

_gradient_accumulation_op = C.MultitypeFuncGraph("gradient_accumulation_op")

@_gradient_accumulation_op.register("Int64", "Tensor", "Tensor")
def _cumulative_grad(accumulation_step, cumulative_grad, grad):
    """Apply gradient accumulation to cumulative grad."""
    return P.AssignAdd()(cumulative_grad, grad / accumulation_step)

_gradient_clear_op = C.MultitypeFuncGraph("gradient_clear_op")

@_gradient_clear_op.register("Tensor")
def _clear_grad(cumulative_grad):
    zero_grad = P.ZerosLike()(cumulative_grad)
    return F.assign(cumulative_grad, zero_grad)

class WithLossCell(Cell):
    r"""
@@ -294,32 +281,6 @@ class ForwardValueAndGrad(Cell):
        return loss, grads


class _TrainFreezeCell(Cell):
    """Gradient freezing training network."""
    def __init__(self, net, sens, grad, grad_reducer, use_grad_accumulation, max_accumulation_step, optimizer):
        super(_TrainFreezeCell, self).__init__(auto_prefix=False)
        self.net = net
        self.grad = grad
        self.grad_reducer = grad_reducer
        self.opt = optimizer
        self.parameters = optimizer.parameters
        self.sens = sens
        self.use_grad_accumulation = use_grad_accumulation
        if use_grad_accumulation:
            self.grad_accumulation = GradientAccumulation(max_accumulation_step, optimizer)

    def construct(self, *inputs):
        loss = self.net(*inputs)
        sens = F.fill(loss.dtype, loss.shape, self.sens)
        grads = self.grad(self.net, self.parameters)(*inputs, sens)
        grads = self.grad_reducer(grads)
        if self.use_grad_accumulation:
            loss = self.grad_accumulation(loss, grads)
        else:
            loss = F.depend(loss, self.opt(grads))
        return loss


class TrainOneStepCell(Cell):
    r"""
    Network training package class.
@@ -395,23 +356,20 @@ class TrainOneStepCell(Cell):
        self.use_grad_accumulation = False
        self.grad_accumulation = None
        if self.use_grad_accumulation:
            self.grad_accumulation = GradientAccumulation(self.max_accumulation_step, self.optimizer)
            self.grad_accumulation = acc.GradientAccumulation(self.max_accumulation_step, self.optimizer)
        if self.reducer_flag:
            self.mean = _get_gradients_mean()
            self.degree = _get_device_num()
            if self.freeze:
                self.grad_reducers = (DistributedGradReducer(opt.parameters, self.mean, self.degree)
                                      for opt in self.optimizer.opts)
                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, reducer,
                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
                                         for reducer, opt in zip(self.grad_reducers, self.optimizer))
                self.freeze_nets = acc.freeze_cell(self.reducer_flag, self.network, self.optimizer, self.sens,
                                                   self.grad, self.use_grad_accumulation, self.mean, self.degree,
                                                   self.max_accumulation_step)
            else:
                self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, self.mean, self.degree)
        else:
            if self.freeze:
                self.freeze_nets = tuple(_TrainFreezeCell(self.network, self.sens, self.grad, self.grad_reducer,
                                                          self.use_grad_accumulation, self.max_accumulation_step, opt)
                                         for opt in self.optimizer.opts)
                self.freeze_nets = acc.freeze_cell(self.reducer_flag, self.network, self.optimizer, self.sens,
                                                   self.grad, self.use_grad_accumulation,
                                                   max_accumulation_step=self.max_accumulation_step)
        self.step = Parameter(Tensor(0, dtype=mstype.int32))

    def construct(self, *inputs):
@@ -726,34 +684,3 @@ class _BroadCastCell(Cell):
        params = self.broadcast(params)
        new_params = self.map_(F.partial(_cast_datatype), datatypes, params)
        return new_params

class GradientAccumulation(Cell):
    """
    After accumulating the gradients of multiple steps, call to optimize its update.

    Args:
        max_accumulation_step (int): Steps to accumulate gradients.
        optimizer(Cell):Optimizer used.
    """
    def __init__(self, max_accumulation_step, optimizer):
        super(GradientAccumulation, self).__init__()
        self._max_accumulation_step = max_accumulation_step
        self.optimizer = optimizer
        self.weights = optimizer.parameters
        self.hyper_map = C.HyperMap()
        self._grad_accumulation = self.weights.clone(prefix="grad_accumulation", init='zeros')
        self._accumulation_step = Parameter(Tensor(0, dtype=mstype.int32), name="accumulation_step")

    def construct(self, loss, grads):
        loss = F.depend(loss, self.hyper_map(F.partial(_gradient_accumulation_op, self._max_accumulation_step),
                                             self._grad_accumulation, grads))
        self._accumulation_step += 1

        if self._accumulation_step >= self._max_accumulation_step:
            loss = F.depend(loss, self.optimizer(self._grad_accumulation))
            self._accumulation_step = 0

        if self._accumulation_step == 0:
            loss = F.depend(loss, self.hyper_map(F.partial(_gradient_clear_op), self._grad_accumulation))

        return loss
@@ -46,6 +46,18 @@ def _init_allreduce_operators(length, split_indices):
    return op_list


def _init_allreduce_operators_by_parameters(parameters):
    """ initialize allreduce communication operators by parameters"""
    op_list = ()
    for parameter in parameters:
        comm_fusion = parameter.comm_fusion
        op = AllReduce('sum', GlobalComm.WORLD_COMM_GROUP)
        op.add_prim_attr('fusion', comm_fusion)
        op.add_prim_attr('index', comm_fusion)
        op_list = op_list + (op,)
    return op_list
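This consumes the comm_fusion indices that ParameterProcess.assign_parameter_group writes onto the parameters, so gradients in the same bucket share a fused AllReduce. A rough sketch of the handshake (net is a placeholder network):

from mindspore.nn.acc import ParameterProcess

params = list(net.trainable_params())
ParameterProcess().assign_parameter_group(params)  # first half keeps comm_fusion 1, second half gets 2
print({p.name: p.comm_fusion for p in params})
# A DistributedGradReducer built with param_fusion=True then calls
# _init_allreduce_operators_by_parameters(params), creating one AllReduce per parameter tagged with
# that parameter's fusion index so that same-index operators can be fused.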


@reduce_opt.register("Tensor", "Bool", "Function", "Function", "Bool", "Tensor")
def _tensors_allreduce(degree, mean, allgather, allreduce, allreduce_filter, grad):
    """
@@ -334,7 +346,7 @@ class DistributedGradReducer(Cell):
        256.0
    """

    def __init__(self, parameters, mean=True, degree=None, fusion_type=1):
    def __init__(self, parameters, mean=True, degree=None, fusion_type=1, param_fusion=False):
        super(DistributedGradReducer, self).__init__(auto_prefix=False)
        self.map_ = C.Map()
        if degree is None:
@@ -351,6 +363,9 @@ class DistributedGradReducer(Cell):
        if is_parallel_optimizer and split_indices:
            self.split_fusion = True
            self.op_list = _init_allreduce_operators(len(parameters), split_indices)
        elif param_fusion:
            self.split_fusion = True
            self.op_list = _init_allreduce_operators_by_parameters(parameters)
        else:
            self.split_fusion = False
            self.allreduce = AllReduce().add_prim_attr('fusion', fusion_type)
@@ -31,6 +31,7 @@ from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
from ..nn.acc import acc
from ..context import ParallelMode
from ..parallel._cost_model_context import _set_multi_subgraphs
from .dataset_helper import DatasetHelper, connect_network_with_dataset
@@ -124,7 +125,7 @@ class Model:
    """

    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
                 eval_indexes=None, amp_level="O0", **kwargs):
                 eval_indexes=None, amp_level="O0", acc_level="O0", **kwargs):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer
@@ -133,6 +134,8 @@ class Model:
        self._keep_bn_fp32 = True
        self._check_kwargs(kwargs)
        self._amp_level = amp_level
        self._acc_level = acc_level
        self._eval_network = eval_network
        self._process_amp_args(kwargs)
        self._parallel_mode = _get_parallel_mode()
        self._device_number = _get_device_num()
@@ -141,8 +144,9 @@ class Model:

        self._check_amp_level_arg(optimizer, amp_level)
        self._check_for_graph_cell(kwargs)
        self._build_acc_network(kwargs)
        self._train_network = self._build_train_network()
        self._build_eval_network(metrics, eval_network, eval_indexes)
        self._build_eval_network(metrics, self._eval_network, eval_indexes)
        self._build_predict_network()

    def _check_for_graph_cell(self, kwargs):
@@ -173,7 +177,7 @@ class Model:

    def _check_kwargs(self, kwargs):
        for arg in kwargs:
            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32', 'total_steps']:
                raise ValueError(f"Unsupported arg '{arg}'")

    def _check_reuse_dataset(self, dataset):
@@ -182,11 +186,25 @@ class Model:
        if hasattr(dataset, '__model_hash__') and dataset.__model_hash__ != hash(self):
            raise RuntimeError('The Dataset cannot be bound to different models, please create a new dataset.')

    def _build_acc_network(self, kwargs):
        """Build the acc network."""
        processor = acc.AutoAcc(self._acc_level, kwargs)
        if self._optimizer is None:
            logger.warning("In acc mode, the optimizer must be defined.")
            return
        if self._eval_network is None:
            logger.warning("In acc mode, the eval_network must be defined.")
            return

        self._network, self._optimizer = processor.network_auto_process_train(self._network, self._optimizer)
        self._eval_network = processor.network_auto_process_eval(self._eval_network)

    def _build_train_network(self):
        """Build train network"""
        network = self._network
        if self._loss_scale_manager is not None and self._optimizer is None:
            raise ValueError("Optimizer can not be None when set loss_scale_manager.")

        if self._optimizer:
            if self._loss_scale_manager_set:
                network = amp.build_train_network(network,