!12853 handle the fully split parameter for grad accumulation

From: @yangzhenzhang
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-03-05 16:53:32 +08:00 committed by Gitee
commit c12abe7a46
15 changed files with 222 additions and 345 deletions


@ -51,7 +51,7 @@
namespace mindspore {
namespace session {
static std::shared_ptr<std::map<ValuePtr, ParameterPtr>> python_paras;
static std::shared_ptr<std::map<ParamInfoPtr, ParameterPtr>> python_paras;
void ClearPythonParasMap() { python_paras = nullptr; }
namespace {
const int kSummaryGetItem = 2;
@ -106,7 +106,7 @@ bool CheckIfNeedCreateOutputTensor(const AnfNodePtr &node) {
return false;
}
ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
}
@ -114,7 +114,7 @@ ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
if (parameter == nullptr || !parameter->has_default()) {
return nullptr;
}
return parameter->default_param();
return parameter->param_info();
}
tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
@ -747,7 +747,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
ParameterPtr new_parameter = nullptr;
// if a backend parameter already exists for this parameter's python parameter, reuse it
if (python_paras == nullptr) {
python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
python_paras = std::make_shared<std::map<ParamInfoPtr, ParameterPtr>>();
}
auto iter = python_paras->find(param_value);
if (iter != python_paras->end()) {
@ -1217,7 +1217,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
auto param_value = GetParamDefaultValue(anf);
ParameterPtr new_parameter = nullptr;
if (python_paras == nullptr) {
python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
python_paras = std::make_shared<std::map<ParamInfoPtr, ParameterPtr>>();
}
auto iter = python_paras->find(param_value);
if (iter != python_paras->end()) {


@ -88,6 +88,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
prim::kPrimMirrorMiniStep);
mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
"mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual add", prim::kPrimVirtualAdd);
check_bprop_eliminate_ =
MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
reset_defer_inline_ =


@ -52,6 +52,7 @@ class OptimizeIRPassLib {
SubstitutionPtr depend_value_elim_;
SubstitutionPtr all_reduce_const_elim_;
SubstitutionPtr mirror_mini_step_elim_;
SubstitutionPtr virtual_add_elim_;
SubstitutionPtr mini_step_allgather_replace_;
// Env Item Eliminate


@ -175,6 +175,25 @@ class MirrorMiniStepEliminater : public AnfVisitor {
void Visit(const AnfNodePtr &) override {}
};
// {prim::kPrimVirtualAdd, X, Z} -> X
class VirtualAddEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
if (!IsPrimitiveCNode(node, prim::kPrimVirtualAdd) || node->func_graph() == nullptr) {
return nullptr;
}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 2) {
return nullptr;
}
return inputs[1];
}
void Visit(const AnfNodePtr &) override {}
};
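A minimal stand-alone sketch of the rewrite this eliminator performs, using plain Python tuples as hypothetical IR nodes (this is not the real AnfNode/CNode API): a _VirtualAdd call node is replaced by its first real input X, so the virtual operator disappears from the optimized forward graph.

# Hedged sketch: nodes are ("op", *inputs) tuples; illustrative only.
def eliminate_virtual_add(node):
    """Rewrite {_VirtualAdd, X, Z} -> X, mirroring VirtualAddEliminater."""
    if isinstance(node, tuple) and node and node[0] == "_VirtualAdd" and len(node) >= 2:
        return node[1]          # keep X, drop the virtual add and Z
    return node                 # anything else is left untouched

assert eliminate_virtual_add(("_VirtualAdd", "x", "accu_grad")) == "x"
assert eliminate_virtual_add(("Mul", "x", "w")) == ("Mul", "x", "w")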
// {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
class MiniStepAllGatherPass : public AnfVisitor {
public:
@ -191,8 +210,15 @@ class MiniStepAllGatherPass : public AnfVisitor {
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
std::string group = attrs[parallel::GROUP]->ToString();
auto fusion = attrs[parallel::FUSION];
parallel::Operator op = parallel::CreateAllGatherOp(group);
std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER);
auto prim_anf_node = node_input[0]->cast<ValueNodePtr>();
prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
attrs = prim->attrs();
attrs[parallel::FUSION] = fusion;
prim->SetAttrs(attrs);
auto func_graph = inputs[1]->func_graph();
CNodePtr new_node = func_graph->NewCNode(node_input);
return new_node;
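A hedged, dict-based sketch of the attribute handling added above: the fusion attribute read from the original _MiniStepAllGather primitive is carried over onto the AllGather primitive that replaces it (plain dicts stand in for primitive attrs; not the real Primitive API).

# Plain-dict stand-in for primitive attrs; illustrative only.
def replace_mini_step_allgather(old_attrs):
    group = old_attrs["group"]
    fusion = old_attrs["fusion"]          # read before building the new op
    new_attrs = {"group": group}          # rough equivalent of CreateAllGatherOp(group)
    new_attrs["fusion"] = fusion          # carry the fusion id over to AllGather
    return ("AllGather", new_attrs)

op, attrs = replace_mini_step_allgather({"group": "hccl_world_group", "fusion": 1})
assert op == "AllGather" and attrs == {"group": "hccl_world_group", "fusion": 1}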


@ -155,13 +155,23 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
if (func_graph->has_flag(AUTO_PARALLEL) &&
(!func_graph->has_flag(TRAINING) ||
(ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !func_graph->has_flag(ACCUMULATION)))) {
if (!func_graph->has_flag(AUTO_PARALLEL)) {
return;
}
if (!func_graph->has_flag(TRAINING)) {
init_param_shape_ = false;
MS_LOG(INFO) << "In parallel evaluation or prediction, may be need to restore the parameter shape";
return;
}
if ((ParallelContext::GetInstance()->grad_accumulation_step() > 1) && !func_graph->has_flag(ACCUMULATION)) {
init_param_shape_ = false;
MS_LOG(INFO) << "In parallel grad accumulation second graph, need to restore the parameter shape";
} else {
param_shapes.clear();
init_param_shape_ = true;
MS_LOG(INFO) << "Init the parameter shape dict";
}
}
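The branching above can be summarized as a small decision function. The sketch below (hedged, plain Python, hypothetical names) returns what happens to the parameter-shape dict for each combination of graph flags.

def param_shape_action(auto_parallel, training, accumulation, grad_accumulation_step):
    """Mirror of ParallelParameterContextInitShape's branches (illustrative only)."""
    if not auto_parallel:
        return "skip"              # nothing to do outside auto/semi-auto parallel
    if not training:
        return "restore"           # evaluation/prediction may restore saved shapes
    if grad_accumulation_step > 1 and not accumulation:
        return "restore"           # second graph of grad accumulation restores shapes
    return "init"                  # normal training graph: clear and re-init the dict

assert param_shape_action(True, True, False, 4) == "restore"
assert param_shape_action(True, True, True, 4) == "init"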
@ -171,6 +181,10 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL)) {
return;
}
if (init_param_shape_) {
return;
}
@ -182,7 +196,7 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
Shape shape = iter->second;
std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
ptr->set_shape(base_shape);
MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
}
// Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
@ -192,6 +206,10 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(param_node);
MS_EXCEPTION_IF_NULL(ptr);
if (!func_graph->has_flag(AUTO_PARALLEL)) {
return;
}
if (!init_param_shape_) {
return;
}


@ -110,6 +110,8 @@ constexpr char STRIDES[] = "strides";
constexpr char GROUP[] = "group";
constexpr char FUSION[] = "fusion";
constexpr char DO_MIRROR[] = "do_mirror";
constexpr char RECOMPUTE[] = "recompute";
constexpr char RECOMPUTE_COMM_OP[] = "recompute_comm_op";
constexpr char NUM_SAMPLED[] = "num_sampled";
constexpr char NUM_TRUE[] = "num_true";
constexpr char SEED[] = "seed";


@ -97,6 +97,27 @@ void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool acc
prim->SetAttrs(attrs);
}
void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, const CNodePtr &node) {
if (new_node_input.empty()) {
return;
}
auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
auto attrs = prim->attrs();
auto anf_node = node->input(0)->cast<ValueNodePtr>();
auto prim_node = GetValueNode<PrimitivePtr>(anf_node);
MS_EXCEPTION_IF_NULL(prim_node);
auto node_attrs = prim_node->attrs();
if (node_attrs.find(RECOMPUTE_COMM_OP) != node_attrs.end() && !GetValue<bool>(node_attrs[RECOMPUTE_COMM_OP])) {
attrs[RECOMPUTE] = MakeValue<bool>(false);
prim->SetAttrs(attrs);
MS_LOG(INFO) << "Do not recompute the forward communication operator of " << prim_node->ToString();
}
}
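A hedged sketch of SetAllReduceRecomputeFlag's attribute logic using plain dicts: if the operator being wrapped carries recompute_comm_op=False, the inserted forward communication primitive is marked recompute=False so it is excluded from recomputation.

def set_allreduce_recompute_flag(comm_attrs, node_attrs):
    """Propagate the no-recompute hint from the wrapped op onto the inserted comm op."""
    if node_attrs.get("recompute_comm_op") is False:
        comm_attrs["recompute"] = False
    return comm_attrs

assert set_allreduce_recompute_flag({}, {"recompute_comm_op": False}) == {"recompute": False}
assert set_allreduce_recompute_flag({}, {}) == {}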
std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
MS_EXCEPTION_IF_NULL(node);
OperatorArgs arg_forward = op.second;
@ -353,6 +374,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
std::string instance_name_base = FORWARD_OP;
std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
SetAllReduceRecomputeFlag(forward_input, node_to_insert);
CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to create anfnode
MS_EXCEPTION_IF_NULL(forward_node);
ScopePtr scope = node->scope();
@ -1165,7 +1187,14 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
// not a RefKey
if (!param_node_pair.second) {
auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph);
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
std::string mirror_op_name;
if (grad_accumulation_step > 1) {
mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
} else {
mirror_op_name = MIRROR_OPERATOR;
}
auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph);
// if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead
if (next_cnode.first) {
MS_EXCEPTION_IF_NULL(next_cnode.second);
@ -1743,6 +1772,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
if (found_be_cloned_parameter) {
// set the shape and tensor layout for cloned parameter
std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
MS_LOG(WARNING) << "The parameter " << param_name << " has not tensor layout, skip it";
continue;
}
cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
@ -3298,6 +3331,97 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
}
}
static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
auto tensor_layout = param_ptr->user_data<parallel::TensorLayout>();
if (tensor_layout == nullptr) {
return false;
}
auto dev_mat_shape = tensor_layout->device_arrangement().array();
auto tensor_map = tensor_layout->tensor_map().array();
int64_t rank = g_device_manager->global_rank();
RankList rank_list = g_device_manager->GetDeviceListInThisStage();
DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
RankList group_devices;
if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
return false;
}
if (group_devices.size() == 1) {
MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
return true;
}
return false;
}
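The check above declares a parameter fully split when the device group derived from its tensor map contains a single device, i.e. no two devices hold the same slice. A minimal sketch of that criterion, assuming MindSpore's convention that a tensor_map entry k indexes the device matrix from the right (k == 0 is the last axis) and -1 means the dimension is not split:

from functools import reduce

def is_fully_split(dev_mat, tensor_map):
    """Return True if every device owns a distinct slice of the parameter (hedged sketch)."""
    used_axes = {len(dev_mat) - 1 - m for m in tensor_map if m >= 0}
    unused = [d for i, d in enumerate(dev_mat) if i not in used_axes]
    replicas = reduce(lambda a, b: a * b, unused, 1)  # devices holding the same slice
    return replicas == 1

# 8 devices arranged as [2, 4]; the parameter is split over both axes -> fully split
assert is_fully_split([2, 4], [1, 0])
# split only over the axis of size 4 -> 2 devices share each slice, not fully split
assert not is_fully_split([2, 4], [0, -1])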
static AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
for (auto &parameter : parameters) {
auto param_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (param_ptr->name() == name) {
continue;
}
if (param_ptr->name().find(name) != std::string::npos && param_ptr->name().find("accu_grad") != std::string::npos) {
return parameter;
}
}
return nullptr;
}
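The matching rule above amounts to: the accumulation clone is a different parameter whose name contains both the original name and the substring "accu_grad". A minimal sketch, assuming the "accu_grads." prefix produced by weights.clone(prefix="accu_grads") as in the test code below:

def find_grad_accu_name(names, name):
    """Mirror FindGradAccuParameter: skip the parameter itself, match its accu clone."""
    for candidate in names:
        if candidate == name:
            continue
        if name in candidate and "accu_grad" in candidate:
            return candidate
    return None

assert find_grad_accu_name(["w1", "accu_grads.w1", "learning_rate"], "w1") == "accu_grads.w1"
assert find_grad_accu_name(["learning_rate"], "learning_rate") is None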
static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
auto cnode = node_user.first->cast<CNodePtr>();
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(WARNING) << cnode->DebugString() << " cannot insert the grad accumulation node for the fully split parameter";
return;
}
OperatorAttrs attrs;
auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
auto value_node = NewValueNode(py_instance);
std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(node_user.second), accu_parameter};
auto graph = cnode->func_graph();
auto virtual_node = graph->NewCNode(virtual_node_input);
manager->SetEdge(cnode, node_user.second, virtual_node);
}
static void HandleFullySplitParameters(const FuncGraphPtr &root) {
int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
if ((grad_accumulation_step <= 1) || root->has_flag(ACCUMULATION)) {
return;
}
auto parameters = root->parameters();
auto node_users_map = root->manager()->node_users();
for (auto &parameter : parameters) {
auto param_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_ptr);
if (!IsFullySplitParameter(param_ptr)) {
continue;
}
auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
if (!accu_parameter) {
continue; // some parameters need no handling, such as the accu parameter itself or the learning rate
}
auto node_users = node_users_map[parameter];
for (auto &user : node_users) {
auto node = user.first;
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!cnode->in_forward_flag()) {
continue;
}
InsertFullySplitParamGradAccu(user, root->manager(), accu_parameter);
MS_LOG(INFO) << "Insert full split assign add node for " << param_ptr->name();
break; // only need to insert once, if the parameter has many users
}
}
}
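Putting the helpers together, the pass rewires only one forward user of each fully split parameter that has an accumulation clone. A hedged toy-graph sketch (dict-based nodes and callbacks, not the real FuncGraph manager; the in_forward_flag filter is omitted):

def handle_fully_split(params, users, fully_split, find_accu):
    """For each fully split parameter with an accu clone, wrap one user's input
    in a hypothetical _VirtualAdd(param, accu_param) node (illustrative only)."""
    edits = []
    for p in params:
        if not fully_split(p):
            continue
        accu = find_accu(p)
        if accu is None:
            continue                   # e.g. the learning rate or the accu clone itself
        for user, input_index in users.get(p, []):
            edits.append((user, input_index, ("_VirtualAdd", p, accu)))
            break                      # insert only once even if the parameter has many users
    return edits

users = {"w1": [("mul_op", 2)]}
assert handle_fully_split(["w1"], users, lambda p: True,
                          lambda p: "accu_grads." + p) == [("mul_op", 2, ("_VirtualAdd", "w1", "accu_grads.w1"))]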
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
@ -3390,6 +3514,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
MS_LOG(EXCEPTION) << "Save group info failed";
}
// handle fully split parameters in grad accumulation; this does not cover optimizer-sharding parameters
HandleFullySplitParameters(root);
DumpGraph(root, std::string(STEP_PARALLEL_END));
// step parallel only run once


@ -159,6 +159,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.switch_layer_defer_inline_,
irpass.replace_applicator_,
irpass.mirror_mini_step_elim_,
irpass.virtual_add_elim_,
irpass.row_tensor_add_zeros_like_,
irpass.mini_step_allgather_replace_,
});


@ -307,6 +307,7 @@ inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOper
inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");


@ -22,7 +22,7 @@ from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive
@ -108,6 +108,14 @@ def get_bprop_receive(self):
return bprop
@bprop_getters.register(_VirtualAdd)
def get_bprop_virtual_add(self):
"""Generate bprop for _VirtualAdd"""
def bprop(x, grad_accu, out, dout):
return (dout + grad_accu, zeros_like(grad_accu))
return bprop
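Numerically, the bprop above folds the accumulated gradient buffer into the gradient that keeps flowing backward, while the buffer itself receives a zero gradient. A quick numpy illustration (the values are made up):

import numpy as np

accu_grad = np.array([0.3, 0.3])   # gradients accumulated over earlier micro-batches
dout = np.array([0.1, 0.2])        # incoming gradient for the current micro-batch

dx = dout + accu_grad              # gradient returned for x by get_bprop_virtual_add
d_accu = np.zeros_like(accu_grad)  # zeros_like(grad_accu): the buffer gets no gradient

assert np.allclose(dx, [0.4, 0.5])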
@bprop_getters.register(Broadcast)
def get_bprop_broad_cast(self):
"""Generate bprop for Broadcast."""
@ -168,13 +176,13 @@ def get_bprop_mini_step_all_gather(self):
def bprop(x, z, out, dout):
if do_mirror:
if mean_flag:
tmp = z + dout
grad = all_reduce(tmp)
z = F.depend(z, F.assign_add(z, dout))
grad = all_reduce(z)
dx = split(grad)[rank]
dx = F.tensor_mul(dx, scale)
else:
tmp = z + dout
grad = all_reduce(tmp)
z = F.depend(z, F.assign_add(z, dout))
grad = all_reduce(z)
dx = split(grad)[rank]
else:
dx = dout
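The diff replaces the out-of-place tmp = z + dout with an in-place assign_add into z, wrapped in F.depend so the update is ordered before the all_reduce. The readable difference is that z, the accumulation buffer, is updated as part of the backward pass rather than only being read; a numpy analogy (illustrative only):

import numpy as np

z = np.zeros(4)                     # persistent accumulation buffer (a Parameter in the real code)
for dout in (np.ones(4), 2 * np.ones(4)):
    z += dout                       # assign_add: accumulate in place; the old code built tmp = z + dout
    # all_reduce(z) would now reduce the running sum across devices

assert np.allclose(z, 3.0)          # both micro-batch gradients are retained in z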
@ -326,7 +334,6 @@ def get_bprop_mirror_mini_step_operator(self):
mean_flag = self.mean_flag
all_reduce = AllReduce(group=group)
all_gather = AllGather(group=group)
mul = P.Mul()
cast = P.Cast()
@ -345,8 +352,8 @@ def get_bprop_mirror_mini_step_operator(self):
if mean_flag:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
tmp = z + dout
real_grad = all_reduce(tmp)
z = F.depend(z, F.assign_add(z, dout))
real_grad = all_reduce(z)
dx = real_grad
else:
dx = dout
@ -354,32 +361,17 @@ def get_bprop_mirror_mini_step_operator(self):
num = F.scalar_cast(dev_num, F.dtype(dx))
dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
else:
if do_mirror:
indices = all_gather(dout.indices)
grad = all_gather(dout.values)
else:
indices = dout.indices
grad = dout.values
float_one = F.scalar_cast(1.0, F.dtype(grad))
num = F.scalar_cast(dev_num, F.dtype(grad))
grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
dx = RowTensor(indices, grad, dout.dense_shape)
dx = zeros_like(x) # Grad accumulation does not support row tensor yet
else:
if F.issubclass_(F.typeof(dout), mstype.tensor):
if do_mirror:
tmp = z + dout
real_grad = all_reduce(tmp)
z = F.depend(z, F.assign_add(z, dout))
real_grad = all_reduce(z)
dx = real_grad
else:
dx = dout
else:
if do_mirror:
indices = all_gather(dout.indices)
grad = all_gather(dout.values)
else:
indices = dout.indices
grad = dout.values
dx = RowTensor(indices, grad, dout.dense_shape)
dx = zeros_like(x) # Grad accumulation does not support row tensor yet
return (dx, zeros_like(z))
return bprop


@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
Unique, GatherD, Identity, Range)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice,
_VirtualDiv, _GetTensorSlice, _VirtualAdd,
_HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)


@ -653,6 +653,19 @@ class _VirtualDiv(PrimitiveWithInfer):
virtual_div = _VirtualDiv()
class _VirtualAdd(PrimitiveWithInfer):
"""Auto parallel virtual operator. Do nothing in forward, do Add in backward."""
@prim_attr_register
def __init__(self):
"""init"""
def infer_shape(self, x_shape, y_shape):
return x_shape
def infer_dtype(self, x_dtype, y_dtype):
return x_dtype
class _VirtualDataset(PrimitiveWithInfer):
"""
Auto parallel virtual dataset operator.


@ -25,6 +25,7 @@ from mindspore.common.initializer import TruncatedNormal, initializer, Normal
from mindspore.ops import operations as P
from mindspore.ops import functional as F
class LayerNorm(nn.Cell):
"""
Layer Normalization


@ -47,6 +47,7 @@ def test_get_parameter_layout():
net = Net(strategy1, strategy2, weight)
net.set_auto_parallel()
net.set_train()
exe = me._executor
exe.compile(net, x, phase='train', auto_parallel_mode=True)
x_layout = ([2, 4], [1, -1], [16, 32], 0, True, '') # device_arrangement = [2, 4], tensor_map = [1, -1]


@ -1,307 +0,0 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter
from mindspore.train import Model
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from mindspore.context import ParallelMode
from mindspore.nn import DistributedGradReducer, DynamicLossScaleUpdateCell, Cell, Momentum, Norm
from mindspore.parallel._utils import _get_device_num
from tests.dataset_mock import MindData
class Dataset(MindData):
def __init__(self, predict, label, length=3):
super(Dataset, self).__init__(size=length)
self.predict = predict
self.label = label
self.index = 0
self.length = length
def __iter__(self):
return self
def __next__(self):
if self.index >= self.length:
raise StopIteration
self.index += 1
return self.predict, self.label
def reset(self):
self.index = 0
get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor")
def _get_square_sum(grad):
norm = P.ReduceSum(False)(F.square(grad), ())
norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
return norm
apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
grad = grad * clip_norm / global_norm
return grad
class GlobalNorm(Cell):
"""
Calculate the global norm value of given tensors
"""
def __init__(self):
super(GlobalNorm, self).__init__()
self.norm = Norm()
self.hyper_map = C.HyperMap()
def construct(self, grads):
square_sum = self.hyper_map(get_square_sum, grads)
global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
return global_norms
class ClipByGlobalNorm(Cell):
"""
Clip grads by global norm
"""
def __init__(self, clip_norm=1.0):
super(ClipByGlobalNorm, self).__init__()
self.global_norm = GlobalNorm()
self.clip_norm = Tensor([clip_norm], mstype.float32)
self.hyper_map = C.HyperMap()
def construct(self, grads):
global_norm = self.global_norm(grads)
cond = P.GreaterEqual()(global_norm, self.clip_norm)
global_norm = F.select(cond, global_norm, self.clip_norm)
grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
return grads
cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
succ = True
return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")
@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
succ = True
return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale)
class TrainAccumulateStepsWithLossScaleCell(Cell):
"""
Encapsulation class of bert network training.
Append an optimizer to the training network after that the construct
function can be called to create the backward graph. To mimic higher batch size, gradients are
accumulated N times before weight update.
Args:
network (Cell): The training network. Note that loss function should have been added.
optimizer (Optimizer): Optimizer for updating the weights.
scale_update_cell (Cell): Cell to do the loss scale. Default: None.
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
batch_size * accumulation_steps. Default: 1.
"""
def __init__(self, network, optimizer, scale_update_cell=None):
super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
self.accu = False
self.is_accu_step = Tensor(np.array([self.accu]))
self.network = network
self.network.set_grad()
self.weights = optimizer.parameters
self.optimizer = optimizer
self.accumulation_steps = context.get_auto_parallel_context("grad_accumulation_step")
self.one = Tensor(np.array([1]).astype(np.int32))
self.zero = Tensor(np.array([0]).astype(np.int32))
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
self.reducer_flag = False
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
self.degree = 1
self.grad_reducer = F.identity
if self.reducer_flag:
self.degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.overflow_reducer = F.identity
if self.is_distributed:
self.overflow_reducer = P.AllReduce()
self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False)
self.base = Tensor(1, mstype.float32)
self.less_equal = P.LessEqual()
self.logical_or = P.LogicalOr()
self.not_equal = P.NotEqual()
self.select = P.Select()
self.reshape = P.Reshape()
self.hyper_map = C.HyperMap()
self.loss_scale = None
self.loss_scaling_manager = scale_update_cell
if scale_update_cell:
self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
@C.add_flags(has_effect=True)
def construct(self, x, b, sens=None):
"""Defines the computation performed."""
weights = self.weights
loss = self.network(x, b)
if sens is None:
scaling_sens = self.loss_scale
else:
scaling_sens = sens
# alloc status and clear should be right before gradoperation
init = self.alloc_status()
self.clear_before_grad(init)
grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))
if self.is_accu_step and self.accumulation_steps > 1:
accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
loss = F.depend(loss, accu_succ)
self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
overflow = self.less_equal(self.base, flag_sum)
overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
accu_overflow = self.select(overflow, self.one, self.zero)
self.accu_overflow = self.select(self.is_accu_step, accu_overflow, self.zero)
if self.is_accu_step:
succ = False
else:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
scaling = scaling_sens * self.degree * self.accumulation_steps
grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
grads = ClipByGlobalNorm()(grads)
accu_overflow = self.overflow_reducer(accu_overflow)
F.control_depend(grads, accu_overflow)
overflow = self.less_equal(self.base, accu_overflow)
accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
overflow = F.depend(overflow, accu_succ)
overflow = self.reshape(overflow, (()))
if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, overflow)
if overflow:
succ = False
else:
succ = self.optimizer(grads)
ret = (loss, overflow, scaling_sens)
return F.depend(ret, succ)
class Net(Cell):
def __init__(self, weight, strategy=None):
super().__init__()
self.mul = P.Mul().shard(strategy)
self.weight = Parameter(weight, "w1")
self.relu = P.ReLU()
self.reduce_sum = P.ReduceSum(keep_dims=True)
def construct(self, x, b):
out = self.mul(x, self.weight)
out = self.relu(out)
out = self.reduce_sum(out)
return out
_x = Tensor(np.ones([2]), dtype=ms.float32)
_b = Tensor(np.ones([16]), dtype=ms.float32)
_w1 = Tensor(np.ones([16]), dtype=ms.float32)
def compile_net(net):
context.set_context(enable_sparse=False)
learning_rate = 0.1
momentum = 0.9
epoch_size = 2
dataset = Dataset(_x, _b)
opt = Momentum(net.trainable_params(), learning_rate, momentum)
update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell)
model = Model(net_wrap)
model.train(epoch_size, dataset, dataset_sink_mode=False)
context.reset_auto_parallel_context()
def test_grad_accumulation_accu():
grad_accumulation_step = 4
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
grad_accumulation_step=grad_accumulation_step)
strategy = ((2,), (2,))
net = Net(_w1, strategy).add_flags_recursive(accu=True)
compile_net(net)
def test_grad_accu_and_opt_shard_accu():
grad_accumulation_step = 4
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
strategy = ((2,), (2,))
net = Net(_w1, strategy).add_flags_recursive(accu=True)
compile_net(net)
def test_grad_accumulation_not_accu():
grad_accumulation_step = 4
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
grad_accumulation_step=grad_accumulation_step)
strategy = ((2,), (2,))
net = Net(_w1, strategy).add_flags_recursive(accu=False)
compile_net(net)
def test_grad_accu_and_opt_shard_not_accu():
grad_accumulation_step = 4
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
strategy = ((2,), (2,))
net = Net(_w1, strategy).add_flags_recursive(accu=False)
compile_net(net)