!10221 support grad accumulation for auto parallel

From: @yangzhenzhang
Committed by mindspore-ci-bot on 2021-01-06 15:27:34 +08:00 via Gitee
commit 1c942ce49f
17 changed files with 613 additions and 37 deletions
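
For orientation: the user-facing switch added by this change is a new grad_accumulation_step option on the auto parallel context. When it is greater than 1, the parallel pass inserts _MirrorMiniStepOperator instead of _MirrorOperator so gradients are only all-reduced on the last micro-step. A minimal usage sketch, modeled on the unit test added at the end of this diff (illustrative, not part of the commit):

from mindspore import context

# Enable mini-step gradient accumulation under semi auto parallel.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                  grad_accumulation_step=4)
# The training network is then wrapped in an accumulation cell that keeps a "local_step"
# parameter and "accu_grads" clones of the weights (see the new test file below) and is
# compiled/trained as usual.
context.reset_auto_parallel_context()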

View File

@@ -80,6 +80,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
    {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
  partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
  same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
  mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate",
                                            prim::kPrimMirrorMiniStep);
  check_bprop_eliminate_ =
    MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
  reset_defer_inline_ =

View File

@@ -51,6 +51,7 @@ class OptimizeIRPassLib {
  SubstitutionPtr reset_defer_inline_;
  SubstitutionPtr depend_value_elim_;
  SubstitutionPtr all_reduce_const_elim_;
  SubstitutionPtr mirror_mini_step_elim_;

  // Env Item Eliminate
  SubstitutionPtr env_get_item_eliminate_;

View File

@@ -155,6 +155,29 @@ class CheckBpropEliminater : public AnfVisitor {
  AnfNodePtr x_{nullptr};
};

// {prim::kPrimMirrorMiniStep, X, Y, Z} -> X
class MirrorMiniStepEliminater : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep) || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr) {
      return nullptr;
    }
    auto inputs = cnode->inputs();
    if (inputs.size() < 2) {
      return nullptr;
    }
    return inputs[1];
  }

  void Visit(const AnfNodePtr &) override {}
};

// Reset defer_inline flag
class ResetDeferInline : public AnfVisitor {
 public:

View File

@@ -64,6 +64,7 @@ void ParallelContext::Reset() {
  all_reduce_fusion_split_sizes_.clear();
  strategy_search_mode_ = DYNAMIC_PROGRAMMING;
  pipeline_stage_split_num_ = 1;
  grad_accumulation_step_ = 1;
}

void ParallelContext::set_device_num(int64_t device_num) {
@@ -80,6 +81,10 @@ void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_
void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }

void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) {
  grad_accumulation_step_ = grad_accumulation_step;
}

void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }

void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }

View File

@@ -73,6 +73,9 @@ class ParallelContext {
  void set_global_rank(int64_t global_rank);
  int64_t global_rank() const { return global_rank_; }

  void set_grad_accumulation_step(int64_t grad_accumulation_step);
  int64_t grad_accumulation_step() const { return grad_accumulation_step_; }

  bool set_parallel_mode(const std::string &parallel_mode);
  std::string parallel_mode() const { return parallel_mode_; }
@@ -116,6 +119,7 @@ class ParallelContext {
  bool loss_repeated_mean_;
  int64_t device_num_;
  int64_t global_rank_;
  int64_t grad_accumulation_step_;
  std::string parallel_mode_;
  std::string strategy_search_mode_;
  int64_t pipeline_stage_split_num_;

View File

@@ -285,8 +285,8 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
  }
  OperatorVector op_for_weight;
  bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  OperatorName operator_name = MIRROR_OPERATOR;

  ValuePtr attr0_value = MakeValue(group_name);
  ValuePtr attr1_value = MakeValue(SizeToLong(dev_num));
  ValuePtr attr2_value = MakeValue(mean_flag);
@@ -300,6 +300,17 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
  operator_attrs.push_back(attr1);
  operator_attrs.push_back(attr2);

  OperatorName operator_name;
  if (grad_accumulation_step > 1) {
    operator_name = MIRROR_MINI_STEP_OPERATOR;
    ValuePtr attr3_value = MakeValue(grad_accumulation_step);
    Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value);
    operator_attrs.push_back(attr3);
    MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror";
  } else {
    operator_name = MIRROR_OPERATOR;
  }

  OperatorParams operator_param;
  OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
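
Viewed from the Python side, the two operators this function can now emit differ only in the extra attribute. A rough illustration of the attributes being packed here (assumes this diff is applied; the values are placeholders):

from mindspore.ops.operations.comm_ops import _MirrorOperator, _MirrorMiniStepOperator

# grad_accumulation_step <= 1: behaviour is unchanged
mirror = _MirrorOperator(group="group_name", dev_num=8, mean_flag=True)
# grad_accumulation_step > 1: the mini-step variant also carries the step count
mirror_mini = _MirrorMiniStepOperator(group="group_name", dev_num=8, mean_flag=True,
                                      grad_accumulation_step=4)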

View File

@@ -146,8 +146,10 @@ constexpr char IS_IN_FORWARD[] = "is_in_forward";
constexpr char DTYPE[] = "DType";
constexpr char DEV_NUM[] = "dev_num";
constexpr char MEAN_FLAG[] = "mean_flag";
constexpr char GRAD_ACCUMULATION_STEP[] = "grad_accumulation_step";
constexpr char TYPES[] = "types";
constexpr char SHAPES[] = "shapes";
constexpr char ACCU_GRADS[] = "accu_grads";
constexpr char GETNEXT_NUM[] = "output_num";
constexpr char SHARED_NAME[] = "shared_name";
constexpr char MIRROR_OP[] = "mirror_op";
@@ -171,6 +173,8 @@ constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis";
constexpr char SPLIT_BY_AXIS[] = "SplitByAxis";
constexpr char ALL_REDUCE[] = "AllReduce";
constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator";
constexpr char LOCAL_STEP[] = "local_step";
constexpr char STRIDED_SLICE[] = "StridedSlice";
constexpr char ALL_GATHER[] = "AllGather";
constexpr char REDUCE_SCATTER[] = "ReduceScatter";

View File

@@ -128,6 +128,137 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
  MS_LOG(INFO) << "Insert " << instance_name << " success";
}

bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // find the clone parameter
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  bool cloned = param_value->cloned();
  if (!cloned) {
    return false;
  }

  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
  return true;
}

std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node,
                                          const std::string &instance_name, const std::string &weight_name) {
  MS_EXCEPTION_IF_NULL(root);
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(root->manager());

  AnfNodePtr local_step_param = nullptr;
  AnfNodePtr grad_accu = nullptr;
  std::string op_name = op.first;
  OperatorArgs arg_forward = op.second;

  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
  if (grad_accumulation_step > 1) {
    bool find_locat_step_node = false;
    auto parameters = root->parameters();
    for (auto &param : parameters) {
      auto param_ptr = param->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param_ptr);
      if (param_ptr->name() == LOCAL_STEP) {
        auto param_users = root->manager()->node_users()[param];
        for (auto &user : param_users) {
          if (AnfNodeIsPrimitive(user.first, ASSIGN)) {
            find_locat_step_node = true;
            local_step_param = user.first;
            MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode";
            break;
          }
        }
        break;
      }
    }

    bool find_grad_accu_node = false;
    for (auto &param : parameters) {
      if (!ParameterIsCloned(param)) {
        continue;
      }

      auto param_ptr = param->cast<ParameterPtr>();
      MS_EXCEPTION_IF_NULL(param_ptr);
      if (param_ptr->name().find(weight_name) != std::string::npos &&
          param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
        find_grad_accu_node = true;
        grad_accu = param;
        MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name();
        break;
      }
    }

    if (op_name == MIRROR_MINI_STEP_OPERATOR) {
      if (!find_locat_step_node || !find_grad_accu_node) {
        op_name = MIRROR_OPERATOR;
        arg_forward.first.pop_back();
      }
    }
  }

  ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name);
  MS_EXCEPTION_IF_NULL(pyop_instance);
  OperatorParams params = arg_forward.second;

  std::vector<AnfNodePtr> new_node_input;
  if (op_name == MIRROR_MINI_STEP_OPERATOR) {
    new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu};
    MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input";
  } else {
    new_node_input = {NewValueNode(pyop_instance), node};
  }

  if (!params.empty()) {
    for (auto &param : params) {
      AnfNodePtr val = NewValueNode(param.first.second);
      MS_EXCEPTION_IF_NULL(val);
      int64_t position = param.second;
      (void)new_node_input.insert(new_node_input.begin() + position, val);
    }
  }

  // if the op have 'group' attr, set the rank list name for the op
  SetCommunicationOpGroupLabel(new_node_input);
  return new_node_input;
}

void InsertMirrorNode(const FuncGraphPtr &root, const Operator &op, const CNodePtr &node, size_t index,
                      const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, const std::string &instance_name,
                      const std::string &param_name) {
  // insert new node before the node
  FuncGraphManagerPtr manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  ScopePtr scope = node->scope();
  MS_EXCEPTION_IF_NULL(scope);
  std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
  CNodePtr new_node = func_graph->NewCNode(node_input);
  MS_EXCEPTION_IF_NULL(new_node);
  if (instance_name.find(SPLIT_SENS) == std::string::npos) {
    new_node->set_in_forward_flag(true);  // mark forward flag
  }
  auto new_node_value = node_input[0]->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(new_node_value);
  PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
  new_node_prim->set_instance_name(instance_name);
  new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
  new_node->set_scope(scope);
  node_input[0]->set_scope(scope);
  manager->SetEdge(node, SizeToLong(index), new_node);
  MS_LOG(INFO) << "Insert " << instance_name << " success";
}

// Replace pre_node with pre_node->op
static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
                            const std::string &instance_name) {
@@ -965,7 +1096,7 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par
  MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
}

void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  size_t node_size = node->inputs().size();
  FuncGraphPtr func_graph = node->func_graph();
@@ -997,6 +1128,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
    if (!param_node_pair.first) {
      continue;
    }

    auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
    std::string param_name;
    if (param_ptr != nullptr) {
      param_name = param_ptr->name();
    }

    // not a RefKey
    if (!param_node_pair.second) {
      auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph);
@@ -1028,7 +1166,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
      CNodePtr cnode = node->input(index)->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      AnfNodePtr pre_node = cnode->input(1);
      InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name);
      InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
      auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
      // add fusion flag
      // pipeline mirror would not be set, which should be supported later
@@ -1037,7 +1175,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
    } else {
      for (auto &op : backward_op) {
        AnfNodePtr pre_node = node->input(index);
        InsertNode(op, node, index, pre_node, func_graph, instance_name);
        InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
        auto comm_op = node->input(index)->cast<CNodePtr>();
        // add fusion flag
        // pipeline mirror would not be set, which should be supported later
@@ -1047,7 +1185,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
  }
}

void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
                           const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
  MS_EXCEPTION_IF_NULL(distribute_operator);
  MS_EXCEPTION_IF_NULL(node);
@@ -1061,7 +1199,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
  // insert mirror op
  if (!mirror_ops.empty()) {
    MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
    InsertMirrorOps(mirror_ops, node);
    InsertMirrorOps(root, mirror_ops, node);
  }
  // insert virtual div op
  if (!virtual_div_op.empty() && is_loss_cnode) {
@@ -1519,28 +1657,6 @@ void CoverSliceShape(const FuncGraphPtr &root) {
  g_RefMap.clear();
}

bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
  MS_EXCEPTION_IF_NULL(parameter_node);
  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(cloned_parameter);

  // find the clone parameter
  if (!cloned_parameter->has_default()) {
    return false;
  }
  auto param_value = cloned_parameter->param_info();
  if (param_value == nullptr) {
    return false;
  }
  bool cloned = param_value->cloned();
  if (!cloned) {
    return false;
  }

  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
  return true;
}

void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
  MS_EXCEPTION_IF_NULL(root);
  for (auto &cloned_parameter_node : root->parameters()) {
@@ -2459,7 +2575,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
    // insert backward ops
    if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) {
      BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
      BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
    }

    HandleSpecialNode(distribute_operator, cnode);
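
Note that CreateMirrorInput locates its two extra inputs purely by naming convention: the LOCAL_STEP ("local_step") parameter is found through its Assign user, and the accumulation buffer is the cloned parameter whose name contains both the weight name and ACCU_GRADS ("accu_grads"); if either is missing, the code silently falls back to the plain _MirrorOperator. The accumulation cell in the new unit test satisfies this convention with the following two lines (quoted from the test further down):

self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')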

View File

@@ -82,11 +82,6 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph);

void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);

void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
                           const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs);

// Generate and init parallel operator
OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
                                 const std::vector<Shapes> &shape_list);

View File

@@ -131,6 +131,8 @@ PYBIND11_MODULE(_c_expression, m) {
    .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
    .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
    .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
    .def("get_grad_accumulation_step", &ParallelContext::grad_accumulation_step, "Get grad accumulation step.")
    .def("set_grad_accumulation_step", &ParallelContext::set_grad_accumulation_step, "Set grad accumulation step.")
    .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
    .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.")
    .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices,

View File

@@ -143,6 +143,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
    irpass.check_bprop_eliminate_,
    irpass.switch_layer_defer_inline_,
    irpass.replace_applicator_,
    irpass.mirror_mini_step_elim_,
  });
  opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
  opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);

View File

@@ -206,6 +206,7 @@ inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorM
// Comm ops
inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");

View File

@@ -21,7 +21,7 @@ from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                   _GetTensorSlice, _MirrorOperator, ReduceOp,
                                   _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive
@@ -282,6 +282,82 @@ def get_bprop_mirror_operator(self):
    return bprop


@bprop_getters.register(_MirrorMiniStepOperator)
def get_bprop_mirror_mini_step_operator(self):
    """
    Backpropagator for _MirrorMiniStepOperator, do allreduce or allgather for the devices in the group,
    allgather for sparse feature.
    """
    group = self.group
    dev_num = self.dev_num
    mean_flag = self.mean_flag
    grad_accumulation_step = self.grad_accumulation_step

    all_reduce = AllReduce(group=group)
    all_gather = AllGather(group=group)
    mul = P.Mul()
    cast = P.Cast()
    equal = P.Equal()
    reshape = P.Reshape()

    fusion = 1
    if hasattr(self, 'fusion'):
        fusion = self.fusion
    all_reduce.add_prim_attr("fusion", fusion)
    if hasattr(self, 'parameter'):
        parameter = self.parameter
        all_reduce.add_prim_attr("parameter", parameter)

    if self.instance_name:
        instance_name = "grad_mirror" + self.instance_name
        all_reduce.set_prim_instance_name(instance_name)

    def bprop(x, y, z, out, dout):
        do_mirror = equal(y, grad_accumulation_step)
        do_mirror = reshape(do_mirror, (()))
        if mean_flag:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                if do_mirror:
                    tmp = z + dout
                    real_grad = all_reduce(tmp)
                    dx = real_grad - z
                else:
                    dx = dout
                float_one = F.scalar_cast(1.0, F.dtype(dx))
                num = F.scalar_cast(dev_num, F.dtype(dx))
                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
            else:
                if do_mirror:
                    indices = all_gather(dout.indices)
                    grad = all_gather(dout.values)
                else:
                    indices = dout.indices
                    grad = dout.values
                float_one = F.scalar_cast(1.0, F.dtype(grad))
                num = F.scalar_cast(dev_num, F.dtype(grad))
                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
                dx = RowTensor(indices, grad, dout.dense_shape)
        else:
            if F.issubclass_(F.typeof(dout), mstype.tensor):
                if do_mirror:
                    tmp = z + dout
                    real_grad = all_reduce(tmp)
                    dx = real_grad - z
                else:
                    dx = dout
            else:
                if do_mirror:
                    indices = all_gather(dout.indices)
                    grad = all_gather(dout.values)
                else:
                    indices = dout.indices
                    grad = dout.values
                dx = RowTensor(indices, grad, dout.dense_shape)

        return (dx, zeros_like(y), zeros_like(z))

    return bprop


@bprop_getters.register(_VirtualDiv)
def get_bprop_virtual_div_operator(self):
    """Backpropagator for _VirtualDiv, do Div for the divisor."""

View File

@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                        SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
                        Unique, GatherD, Identity)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                       _MirrorOperator, ReduceOp, _VirtualDataset,
                       _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset,
                       _VirtualDiv, _GetTensorSlice,
                       _HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,

View File

@@ -567,6 +567,35 @@ class _MirrorOperator(PrimitiveWithInfer):

mirror = _MirrorOperator()


class _MirrorMiniStepOperator(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
    internal use of parallel modules and cannot be called by users.

    Args:
        group (str): The communication group to work on. Default: None.
        dev_num (int): The device number of the group. Default: None.
        mean_flag (bool): Whether use mean in backward. Default: None.
        grad_accumulation_step (int): The grad accumulation step. Default: None.
    """

    @prim_attr_register
    def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None):
        self.group = group
        self.dev_num = dev_num
        self.mean_flag = mean_flag
        self.grad_accumulation_step = grad_accumulation_step

    def infer_shape(self, x_shape, y_shape, z_shape):
        return x_shape

    def infer_dtype(self, x_dtype, y_shape, z_shape):
        return x_dtype


mirror_mini_step = _MirrorMiniStepOperator()


class _VirtualDiv(PrimitiveWithInfer):
    """
    Auto parallel virtual operator. Do nothing in forward, do Div in backward.

View File

@@ -249,6 +249,21 @@ class _AutoParallelContext:
            return False
        return self._context_handle.get_full_batch()

    def set_grad_accumulation_step(self, grad_accumulation_step):
        """
        Set grad accumulation step.

        Args:
            grad_accumulation_step (int): The grad accumulation step.
        """
        self.check_context_handle()
        self._context_handle.set_grad_accumulation_step(grad_accumulation_step)

    def get_grad_accumulation_step(self):
        """Get grad accumulation step."""
        self.check_context_handle()
        return self._context_handle.get_grad_accumulation_step()

    def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
        """
        Set strategy checkpoint save path.
@@ -492,6 +507,7 @@ _set_auto_parallel_context_func_map = {
    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
    "full_batch": auto_parallel_context().set_full_batch,
    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices}
@@ -509,6 +525,7 @@ _get_auto_parallel_context_func_map = {
    "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
    "full_batch": auto_parallel_context().get_full_batch,
    "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
    "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices}
@@ -516,7 +533,7 @@ _get_auto_parallel_context_func_map = {
                 loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                 parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 all_reduce_fusion_config=list)
                 grad_accumulation_step=int, all_reduce_fusion_config=list)
def _set_auto_parallel_context(**kwargs):
    """

View File

@@ -0,0 +1,289 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum, Norm
from mindspore.train import Model
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.context import ParallelMode
from tests.dataset_mock import MindData


class Dataset(MindData):
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


get_square_sum = C.MultitypeFuncGraph("get_square_sum")


@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")


@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad


class GlobalNorm(Cell):
    """
    Calculate the global norm value of given tensors
    """
    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
        return global_norms


class ClipByGlobalNorm(Cell):
    """
    Clip grads by global norm
    """
    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads


cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))


zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)


class TrainAccumulateStepsWithLossScaleCell(Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network after that the construct
    function can be called to create the backward graph. To mimic higher batch size, gradients are
    accumulated N times before weight update.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                  batch_size * accumulation_steps. Default: 1.
    """
    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4):
        super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    @C.add_flags(has_effect=True)
    def construct(self, x, b, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(x, b)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))

        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
        is_accu_step = self.reshape(is_accu_step, (()))

        if is_accu_step:
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            grads = ClipByGlobalNorm()(grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            F.control_depend(grads, accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)


class Net(Cell):
    def __init__(self, weight, strategy=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy)
        self.weight = Parameter(weight, "w1")
        self.relu = P.ReLU()
        self.reduce_sum = P.ReduceSum(keep_dims=True)

    def construct(self, x, b):
        out = self.mul(x, self.weight)
        out = self.relu(out)
        out = self.reduce_sum(out)
        return out


_x = Tensor(np.ones([2]), dtype=ms.float32)
_b = Tensor(np.ones([16]), dtype=ms.float32)
_w1 = Tensor(np.ones([16]), dtype=ms.float32)


def compile_net(net, grad_accumulation_step):
    context.set_context(save_graphs=True)
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    dataset = Dataset(_x, _b)
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
    net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell,
                                                     accumulation_steps=grad_accumulation_step)
    model = Model(net_wrap)
    model.train(epoch_size, dataset, dataset_sink_mode=False)
    context.reset_auto_parallel_context()


def test_grad_accumulation():
    grad_accumulation_step = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=grad_accumulation_step)
    strategy = ((2,), (2,))
    net = Net(_w1, strategy)
    compile_net(net, grad_accumulation_step)