!10221 support grad accumulation for auto parallel
From: @yangzhenzhang
Commit: 1c942ce49f
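The change, in brief: when grad_accumulation_step is set to N under semi-auto/auto parallel, each device folds its local gradients into cloned accu_grads parameters for N mini-steps and only performs the mirror AllReduce (and the optimizer update) on the N-th step. A rough pure-Python sketch of that arithmetic, offered as a reading aid rather than MindSpore code:

import numpy as np

def accumulate_and_sync(step_grad, accu_grad, local_step, n_steps, all_reduce_sum):
    # every mini-step: fold the new gradient into the local accumulator
    accu_grad += step_grad
    if local_step == n_steps:
        # boundary step: synchronize the accumulated gradient across devices
        synced = all_reduce_sum(accu_grad)
        accu_grad[:] = 0  # reset for the next accumulation window
        return synced
    return None  # no parameter update on intermediate steps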
@@ -80,6 +80,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
     {prim::kPrimReduceMean, prim::kPrimReduceAll, prim::kPrimReduceSum, prim::kPrimReduceMax, prim::kPrimReduceMin});
   partial_eliminate_ = MakeSubstitution(std::make_shared<PartialEliminater>(), "partial_eliminate", IsCNodeDup);
   same_eliminate_ = MakeSubstitution(std::make_shared<SameEliminater>(), "same_eliminate", prim::kPrimSameTypeShape);
+  mirror_mini_step_elim_ = MakeSubstitution(std::make_shared<MirrorMiniStepEliminater>(), "mirror_mini_step_eliminate",
+                                            prim::kPrimMirrorMiniStep);
   check_bprop_eliminate_ =
     MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
   reset_defer_inline_ =
@@ -51,6 +51,7 @@ class OptimizeIRPassLib {
   SubstitutionPtr reset_defer_inline_;
   SubstitutionPtr depend_value_elim_;
   SubstitutionPtr all_reduce_const_elim_;
+  SubstitutionPtr mirror_mini_step_elim_;

   // Env Item Eliminate
   SubstitutionPtr env_get_item_eliminate_;
@@ -155,6 +155,29 @@ class CheckBpropEliminater : public AnfVisitor {
   AnfNodePtr x_{nullptr};
 };

+// {prim::kPrimMirrorMiniStep, X, Y, Z} -> X
+class MirrorMiniStepEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (!IsPrimitiveCNode(node, prim::kPrimMirrorMiniStep) || node->func_graph() == nullptr) {
+      return nullptr;
+    }
+
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr) {
+      return nullptr;
+    }
+    auto inputs = cnode->inputs();
+    if (inputs.size() < 2) {
+      return nullptr;
+    }
+
+    return inputs[1];
+  }
+
+  void Visit(const AnfNodePtr &) override {}
+};
+
 // Reset defer_inline flag
 class ResetDeferInline : public AnfVisitor {
  public:
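The eliminator's rewrite rule is mechanical: a MirrorMiniStep(X, local_step, accu_grad) node collapses to its first data input X once mini-step handling is no longer needed. A toy Python rendering of the same rule, using a hypothetical tuple representation of a node (op name first, inputs after), not the real IR:

def eliminate_mirror_mini_step(node):
    # node is modeled as a tuple: (op_name, *inputs)
    if node[0] != "MirrorMiniStep" or len(node) < 2:
        return None  # pattern does not match; leave the node alone
    return node[1]   # rewrite {MirrorMiniStep, X, Y, Z} -> X

assert eliminate_mirror_mini_step(("MirrorMiniStep", "X", "Y", "Z")) == "X"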
@@ -64,6 +64,7 @@ void ParallelContext::Reset() {
   all_reduce_fusion_split_sizes_.clear();
   strategy_search_mode_ = DYNAMIC_PROGRAMMING;
   pipeline_stage_split_num_ = 1;
+  grad_accumulation_step_ = 1;
 }

 void ParallelContext::set_device_num(int64_t device_num) {

@@ -80,6 +81,10 @@ void ParallelContext::set_gradients_mean(bool gradients_mean) { gradients_mean_

 void ParallelContext::set_full_batch(bool full_batch) { full_batch_ = full_batch; }

+void ParallelContext::set_grad_accumulation_step(int64_t grad_accumulation_step) {
+  grad_accumulation_step_ = grad_accumulation_step;
+}
+
 void ParallelContext::set_gradient_fp32_sync(bool gradient_fp32_sync) { gradient_fp32_sync_ = gradient_fp32_sync; }

 void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; }

@@ -73,6 +73,9 @@ class ParallelContext {
   void set_global_rank(int64_t global_rank);
   int64_t global_rank() const { return global_rank_; }

+  void set_grad_accumulation_step(int64_t grad_accumulation_step);
+  int64_t grad_accumulation_step() const { return grad_accumulation_step_; }
+
   bool set_parallel_mode(const std::string &parallel_mode);
   std::string parallel_mode() const { return parallel_mode_; }

@@ -116,6 +119,7 @@ class ParallelContext {
   bool loss_repeated_mean_;
   int64_t device_num_;
   int64_t global_rank_;
+  int64_t grad_accumulation_step_;
   std::string parallel_mode_;
   std::string strategy_search_mode_;
   int64_t pipeline_stage_split_num_;
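This context plumbing is what the Python front end ultimately talks to, and the new field is exercised the same way as the other parallel knobs. Usage as in the new test at the bottom of this diff:

from mindspore import context

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8,
                                  global_rank=0, grad_accumulation_step=4)
print(context.get_auto_parallel_context("grad_accumulation_step"))  # 4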
@@ -285,8 +285,8 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
   }
   OperatorVector op_for_weight;
   bool mean_flag = ParallelContext::GetInstance()->gradients_mean();
+  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();

-  OperatorName operator_name = MIRROR_OPERATOR;
   ValuePtr attr0_value = MakeValue(group_name);
   ValuePtr attr1_value = MakeValue(SizeToLong(dev_num));
   ValuePtr attr2_value = MakeValue(mean_flag);

@@ -300,6 +300,17 @@ OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) {
   operator_attrs.push_back(attr1);
   operator_attrs.push_back(attr2);

+  OperatorName operator_name;
+  if (grad_accumulation_step > 1) {
+    operator_name = MIRROR_MINI_STEP_OPERATOR;
+    ValuePtr attr3_value = MakeValue(grad_accumulation_step);
+    Attr attr3 = std::make_pair(GRAD_ACCUMULATION_STEP, attr3_value);
+    operator_attrs.push_back(attr3);
+    MS_LOG(INFO) << "The grad accumulation step is " << grad_accumulation_step << ", use mini step mirror";
+  } else {
+    operator_name = MIRROR_OPERATOR;
+  }
+
   OperatorParams operator_param;
   OperatorArgs operator_args = std::make_pair(operator_attrs, operator_param);
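The operator selection above is a plain branch on the context value: the mini-step mirror is the ordinary mirror plus one extra attribute. A condensed Python sketch of the same decision (names mirror the C++ constants; this is not a real API):

def choose_mirror_operator(grad_accumulation_step, operator_attrs):
    # mini-step mirror carries one extra attribute, otherwise fall back
    if grad_accumulation_step > 1:
        operator_attrs.append(("grad_accumulation_step", grad_accumulation_step))
        return "_MirrorMiniStepOperator"
    return "_MirrorOperator"

attrs = [("group", "g1"), ("dev_num", 8), ("mean_flag", True)]
assert choose_mirror_operator(4, attrs) == "_MirrorMiniStepOperator"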
@@ -146,8 +146,10 @@ constexpr char IS_IN_FORWARD[] = "is_in_forward";
 constexpr char DTYPE[] = "DType";
 constexpr char DEV_NUM[] = "dev_num";
 constexpr char MEAN_FLAG[] = "mean_flag";
+constexpr char GRAD_ACCUMULATION_STEP[] = "grad_accumulation_step";
 constexpr char TYPES[] = "types";
 constexpr char SHAPES[] = "shapes";
+constexpr char ACCU_GRADS[] = "accu_grads";
 constexpr char GETNEXT_NUM[] = "output_num";
 constexpr char SHARED_NAME[] = "shared_name";
 constexpr char MIRROR_OP[] = "mirror_op";

@@ -171,6 +173,8 @@ constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis";
 constexpr char SPLIT_BY_AXIS[] = "SplitByAxis";
 constexpr char ALL_REDUCE[] = "AllReduce";
 constexpr char MIRROR_OPERATOR[] = "_MirrorOperator";
+constexpr char MIRROR_MINI_STEP_OPERATOR[] = "_MirrorMiniStepOperator";
+constexpr char LOCAL_STEP[] = "local_step";
 constexpr char STRIDED_SLICE[] = "StridedSlice";
 constexpr char ALL_GATHER[] = "AllGather";
 constexpr char REDUCE_SCATTER[] = "ReduceScatter";
@@ -128,6 +128,137 @@ void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const An
   MS_LOG(INFO) << "Insert " << instance_name << " success";
 }

+bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
+  MS_EXCEPTION_IF_NULL(parameter_node);
+  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
+  MS_EXCEPTION_IF_NULL(cloned_parameter);
+
+  // find the clone parameter
+  if (!cloned_parameter->has_default()) {
+    return false;
+  }
+  auto param_value = cloned_parameter->param_info();
+  if (param_value == nullptr) {
+    return false;
+  }
+  bool cloned = param_value->cloned();
+  if (!cloned) {
+    return false;
+  }
+
+  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
+  return true;
+}
+
+std::vector<AnfNodePtr> CreateMirrorInput(const FuncGraphPtr &root, const Operator &op, const AnfNodePtr &node,
+                                          const std::string &instance_name, const std::string &weight_name) {
+  MS_EXCEPTION_IF_NULL(root);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(root->manager());
+
+  AnfNodePtr local_step_param = nullptr;
+  AnfNodePtr grad_accu = nullptr;
+  std::string op_name = op.first;
+  OperatorArgs arg_forward = op.second;
+
+  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+
+  if (grad_accumulation_step > 1) {
+    bool find_local_step_node = false;
+    auto parameters = root->parameters();
+    for (auto &param : parameters) {
+      auto param_ptr = param->cast<ParameterPtr>();
+      MS_EXCEPTION_IF_NULL(param_ptr);
+      if (param_ptr->name() == LOCAL_STEP) {
+        auto param_users = root->manager()->node_users()[param];
+        for (auto &user : param_users) {
+          if (AnfNodeIsPrimitive(user.first, ASSIGN)) {
+            find_local_step_node = true;
+            local_step_param = user.first;
+            MS_LOG(INFO) << "Find the local step when create mirror, it may be in the mini step grad accumulation mode";
+            break;
+          }
+        }
+        break;
+      }
+    }
+
+    bool find_grad_accu_node = false;
+    for (auto &param : parameters) {
+      if (!ParameterIsCloned(param)) {
+        continue;
+      }
+
+      auto param_ptr = param->cast<ParameterPtr>();
+      MS_EXCEPTION_IF_NULL(param_ptr);
+      if (param_ptr->name().find(weight_name) != std::string::npos &&
+          param_ptr->name().find(ACCU_GRADS) != std::string::npos) {
+        find_grad_accu_node = true;
+        grad_accu = param;
+        MS_LOG(INFO) << "Find the accumulation grad node: " << param_ptr->name();
+        break;
+      }
+    }
+
+    if (op_name == MIRROR_MINI_STEP_OPERATOR) {
+      if (!find_local_step_node || !find_grad_accu_node) {
+        op_name = MIRROR_OPERATOR;
+        arg_forward.first.pop_back();
+      }
+    }
+  }
+
+  ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op_name, instance_name);
+  MS_EXCEPTION_IF_NULL(pyop_instance);
+  OperatorParams params = arg_forward.second;
+
+  std::vector<AnfNodePtr> new_node_input;
+  if (op_name == MIRROR_MINI_STEP_OPERATOR) {
+    new_node_input = {NewValueNode(pyop_instance), node, local_step_param, grad_accu};
+    MS_LOG(INFO) << "Insert the local step node and grad accumulation node as the mirror op's input";
+  } else {
+    new_node_input = {NewValueNode(pyop_instance), node};
+  }
+
+  if (!params.empty()) {
+    for (auto &param : params) {
+      AnfNodePtr val = NewValueNode(param.first.second);
+      MS_EXCEPTION_IF_NULL(val);
+      int64_t position = param.second;
+      (void)new_node_input.insert(new_node_input.begin() + position, val);
+    }
+  }
+
+  // if the op has a 'group' attr, set the rank list name for the op
+  SetCommunicationOpGroupLabel(new_node_input);
+  return new_node_input;
+}
+
+void InsertMirrorNode(const FuncGraphPtr &root, const Operator &op, const CNodePtr &node, size_t index,
+                      const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph, const std::string &instance_name,
+                      const std::string &param_name) {
+  // insert new node before the node
+  FuncGraphManagerPtr manager = func_graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  ScopePtr scope = node->scope();
+  MS_EXCEPTION_IF_NULL(scope);
+  std::vector<AnfNodePtr> node_input = CreateMirrorInput(root, op, pre_node, instance_name, param_name);
+  CNodePtr new_node = func_graph->NewCNode(node_input);
+  MS_EXCEPTION_IF_NULL(new_node);
+  if (instance_name.find(SPLIT_SENS) == std::string::npos) {
+    new_node->set_in_forward_flag(true);  // mark forward flag
+  }
+  auto new_node_value = node_input[0]->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(new_node_value);
+  PrimitivePtr new_node_prim = new_node_value->value()->cast<PrimitivePtr>();
+  new_node_prim->set_instance_name(instance_name);
+  new_node_prim->set_attr("keep_value_node_input", MakeValue(true));
+  new_node->set_scope(scope);
+  node_input[0]->set_scope(scope);
+  manager->SetEdge(node, SizeToLong(index), new_node);
+  MS_LOG(INFO) << "Insert " << instance_name << " success";
+}
+
 // Replace pre_node with pre_node->op
 static CNodePtr ReplaceNode(const Operator &op, const AnfNodePtr &pre_node, const FuncGraphPtr &func_graph,
                             const std::string &instance_name) {
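CreateMirrorInput is the piece that wires up the extra operands: it scans the root graph's parameters for the local_step parameter (via its Assign user) and for the accu_grads clone that matches the weight name, and silently falls back to the plain mirror when either is missing. A simplified Python sketch of that lookup over toy name lists, not real graph objects:

def build_mirror_inputs(op_name, node, param_names, weight_name):
    # find the "local_step" parameter and the matching "accu_grads" clone
    local_step = next((p for p in param_names if p == "local_step"), None)
    grad_accu = next((p for p in param_names
                      if weight_name in p and "accu_grads" in p), None)
    if op_name == "_MirrorMiniStepOperator" and (local_step is None or grad_accu is None):
        op_name = "_MirrorOperator"  # fall back: accumulation state not found
    if op_name == "_MirrorMiniStepOperator":
        return op_name, [node, local_step, grad_accu]
    return op_name, [node]

params = ["local_step", "accu_grads.w1", "w1"]
print(build_mirror_inputs("_MirrorMiniStepOperator", "w1_load", params, "w1"))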
@@ -965,7 +1096,7 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par
   MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
 }

-void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
+void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   size_t node_size = node->inputs().size();
   FuncGraphPtr func_graph = node->func_graph();
@@ -997,6 +1128,13 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
     if (!param_node_pair.first) {
       continue;
     }
+
+    auto param_ptr = param_node_pair.first->cast<ParameterPtr>();
+    std::string param_name;
+    if (param_ptr != nullptr) {
+      param_name = param_ptr->name();
+    }
+
     // not a RefKey
     if (!param_node_pair.second) {
       auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph);
@@ -1028,7 +1166,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
       CNodePtr cnode = node->input(index)->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(cnode);
       AnfNodePtr pre_node = cnode->input(1);
-      InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name);
+      InsertMirrorNode(root, op, cnode, size_t(1), pre_node, func_graph, instance_name, param_name);
       auto comm_op = cnode->input(size_t(1))->cast<CNodePtr>();
       // add fusion flag
       // pipeline mirror would not be set, which should be supported later
@@ -1037,7 +1175,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
     } else {
       for (auto &op : backward_op) {
         AnfNodePtr pre_node = node->input(index);
-        InsertNode(op, node, index, pre_node, func_graph, instance_name);
+        InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
         auto comm_op = node->input(index)->cast<CNodePtr>();
         // add fusion flag
         // pipeline mirror would not be set, which should be supported later
@@ -1047,7 +1185,7 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
   }
 }

-void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
+void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
                            const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
   MS_EXCEPTION_IF_NULL(distribute_operator);
   MS_EXCEPTION_IF_NULL(node);
@@ -1061,7 +1199,7 @@ void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNo
   // insert mirror op
   if (!mirror_ops.empty()) {
     MS_LOG(INFO) << "insert mirror op for " << distribute_operator->name();
-    InsertMirrorOps(mirror_ops, node);
+    InsertMirrorOps(root, mirror_ops, node);
   }
   // insert virtual div op
   if (!virtual_div_op.empty() && is_loss_cnode) {
@@ -1519,28 +1657,6 @@ void CoverSliceShape(const FuncGraphPtr &root) {
   g_RefMap.clear();
 }

-bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
-  MS_EXCEPTION_IF_NULL(parameter_node);
-  auto cloned_parameter = parameter_node->cast<ParameterPtr>();
-  MS_EXCEPTION_IF_NULL(cloned_parameter);
-
-  // find the clone parameter
-  if (!cloned_parameter->has_default()) {
-    return false;
-  }
-  auto param_value = cloned_parameter->param_info();
-  if (param_value == nullptr) {
-    return false;
-  }
-  bool cloned = param_value->cloned();
-  if (!cloned) {
-    return false;
-  }
-
-  MS_LOG(INFO) << "The parameter: " << cloned_parameter->name() << " is cloned";
-  return true;
-}
-
 void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
   MS_EXCEPTION_IF_NULL(root);
   for (auto &cloned_parameter_node : root->parameters()) {
@@ -2459,7 +2575,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt

     // insert backward ops
     if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) {
-      BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
+      BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs);
     }

     HandleSpecialNode(distribute_operator, cnode);
@@ -82,11 +82,6 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap

 std::pair<bool, CNodePtr> FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph);

-void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node);
-
-void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
-                           const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs);
-
 // Generate and init parallel operator
 OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs,
                                  const std::vector<Shapes> &shape_list);
@@ -131,6 +131,8 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
     .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
     .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
+    .def("get_grad_accumulation_step", &ParallelContext::grad_accumulation_step, "Get grad accumulation step.")
+    .def("set_grad_accumulation_step", &ParallelContext::set_grad_accumulation_step, "Set grad accumulation step.")
    .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
     .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.")
     .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices,
@@ -143,6 +143,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     irpass.check_bprop_eliminate_,
     irpass.switch_layer_defer_inline_,
     irpass.replace_applicator_,
+    irpass.mirror_mini_step_elim_,
   });
   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
   opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true);
@@ -206,6 +206,7 @@ inline const PrimitivePtr kPrimTensorMove = std::make_shared<Primitive>("TensorM

 // Comm ops
 inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
+inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
 inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
 inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
@@ -21,7 +21,7 @@ from .. import operations as P
 from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
-                                   _GetTensorSlice, _MirrorOperator, ReduceOp,
+                                   _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
                                    ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
 from .grad_base import bprop_getters
 from ..operations._inner_ops import Send, Receive
@@ -282,6 +282,82 @@ def get_bprop_mirror_operator(self):
     return bprop


+@bprop_getters.register(_MirrorMiniStepOperator)
+def get_bprop_mirror_mini_step_operator(self):
+    """
+    Backpropagator for _MirrorMiniStepOperator, do allreduce or allgather for the devices in the group,
+    allgather for sparse feature.
+    """
+    group = self.group
+    dev_num = self.dev_num
+    mean_flag = self.mean_flag
+    grad_accumulation_step = self.grad_accumulation_step
+
+    all_reduce = AllReduce(group=group)
+    all_gather = AllGather(group=group)
+    mul = P.Mul()
+    cast = P.Cast()
+    equal = P.Equal()
+    reshape = P.Reshape()
+
+    fusion = 1
+    if hasattr(self, 'fusion'):
+        fusion = self.fusion
+    all_reduce.add_prim_attr("fusion", fusion)
+    if hasattr(self, 'parameter'):
+        parameter = self.parameter
+        all_reduce.add_prim_attr("parameter", parameter)
+
+    if self.instance_name:
+        instance_name = "grad_mirror" + self.instance_name
+        all_reduce.set_prim_instance_name(instance_name)
+
+    def bprop(x, y, z, out, dout):
+        do_mirror = equal(y, grad_accumulation_step)
+        do_mirror = reshape(do_mirror, (()))
+        if mean_flag:
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                if do_mirror:
+                    tmp = z + dout
+                    real_grad = all_reduce(tmp)
+                    dx = real_grad - z
+                else:
+                    dx = dout
+                float_one = F.scalar_cast(1.0, F.dtype(dx))
+                num = F.scalar_cast(dev_num, F.dtype(dx))
+                dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
+            else:
+                if do_mirror:
+                    indices = all_gather(dout.indices)
+                    grad = all_gather(dout.values)
+                else:
+                    indices = dout.indices
+                    grad = dout.values
+                float_one = F.scalar_cast(1.0, F.dtype(grad))
+                num = F.scalar_cast(dev_num, F.dtype(grad))
+                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
+                dx = RowTensor(indices, grad, dout.dense_shape)
+        else:
+            if F.issubclass_(F.typeof(dout), mstype.tensor):
+                if do_mirror:
+                    tmp = z + dout
+                    real_grad = all_reduce(tmp)
+                    dx = real_grad - z
+                else:
+                    dx = dout
+            else:
+                if do_mirror:
+                    indices = all_gather(dout.indices)
+                    grad = all_gather(dout.values)
+                else:
+                    indices = dout.indices
+                    grad = dout.values
+                dx = RowTensor(indices, grad, dout.dense_shape)
+
+        return (dx, zeros_like(y), zeros_like(z))
+    return bprop
+
+
 @bprop_getters.register(_VirtualDiv)
 def get_bprop_virtual_div_operator(self):
     """Backpropagator for _VirtualDiv, do Div for the divisor."""
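The backward rule only synchronizes on the boundary mini-step: do_mirror compares the local step y against grad_accumulation_step, and on intermediate steps the incoming gradient passes through untouched. A numeric illustration of the dense mean_flag branch, with plain numpy standing in for AllReduce over two devices holding identical data:

import numpy as np

def mini_step_bprop(dout, z, y, grad_accumulation_step, dev_num, all_reduce_sum):
    # z is the accumulated gradient parameter; y is the local step counter
    if y == grad_accumulation_step:           # boundary step: do_mirror is true
        real_grad = all_reduce_sum(z + dout)  # reduce accumulated + current grad
        dx = real_grad - z
    else:                                     # intermediate step: pass through
        dx = dout
    return dx * (1.0 / dev_num)               # mean_flag: divide by device count

two_devices = lambda t: t * 2  # identical tensors on 2 devices: sum == 2 * t
print(mini_step_bprop(np.ones(3), np.zeros(3), 4, 4, 2, two_devices))  # [1. 1. 1.]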
@@ -35,7 +35,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
                         SpaceToBatchND, BatchToSpaceND, BroadcastTo, InplaceUpdate, ReverseSequence, EmbeddingLookup,
                         Unique, GatherD, Identity)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
-                       _MirrorOperator, ReduceOp, _VirtualDataset,
+                       _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, _VirtualDataset,
                        _VirtualDiv, _GetTensorSlice,
                        _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
@@ -567,6 +567,35 @@ class _MirrorOperator(PrimitiveWithInfer):
 mirror = _MirrorOperator()


+class _MirrorMiniStepOperator(PrimitiveWithInfer):
+    """
+    Auto parallel virtual operator. Do nothing in forward, do all reduce and mean in backward. It is only for
+    internal use of parallel modules and cannot be called by users.
+
+    Args:
+        group (str): The communication group to work on. Default: None.
+        dev_num (int): The device number of the group. Default: None.
+        mean_flag (bool): Whether use mean in backward. Default: None.
+        grad_accumulation_step (int): The grad accumulation step. Default: None.
+    """
+
+    @prim_attr_register
+    def __init__(self, group=None, dev_num=None, mean_flag=None, grad_accumulation_step=None):
+        self.group = group
+        self.dev_num = dev_num
+        self.mean_flag = mean_flag
+        self.grad_accumulation_step = grad_accumulation_step
+
+    def infer_shape(self, x_shape, y_shape, z_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype, y_shape, z_shape):
+        return x_dtype
+
+
+mirror_mini_step = _MirrorMiniStepOperator()
+
+
 class _VirtualDiv(PrimitiveWithInfer):
     """
     Auto parallel virtual operator. Do nothing in forward, do Div in backward.
@@ -249,6 +249,21 @@ class _AutoParallelContext:
             return False
         return self._context_handle.get_full_batch()

+    def set_grad_accumulation_step(self, grad_accumulation_step):
+        """
+        Set grad accumulation step.
+
+        Args:
+            grad_accumulation_step (int): The grad accumulation step.
+        """
+        self.check_context_handle()
+        self._context_handle.set_grad_accumulation_step(grad_accumulation_step)
+
+    def get_grad_accumulation_step(self):
+        """Get grad accumulation step."""
+        self.check_context_handle()
+        return self._context_handle.get_grad_accumulation_step()
+
     def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):
        """
        Set strategy checkpoint save path.
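Equivalent lower-level access goes through the auto_parallel_context singleton that these methods (and the pybind exports above) back. In this codebase the singleton lives in mindspore.parallel._auto_parallel_context:

from mindspore.parallel._auto_parallel_context import auto_parallel_context

auto_parallel_context().set_grad_accumulation_step(4)
assert auto_parallel_context().get_grad_accumulation_step() == 4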
@@ -492,6 +507,7 @@ _set_auto_parallel_context_func_map = {
     "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
     "full_batch": auto_parallel_context().set_full_batch,
     "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
+    "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
     "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices}


@@ -509,6 +525,7 @@ _get_auto_parallel_context_func_map = {
     "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file,
     "full_batch": auto_parallel_context().get_full_batch,
     "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer,
+    "grad_accumulation_step": auto_parallel_context().get_grad_accumulation_step,
     "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices}


@@ -516,7 +533,7 @@ _get_auto_parallel_context_func_map = {
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
-                 all_reduce_fusion_config=list)
+                 grad_accumulation_step=int, all_reduce_fusion_config=list)

 def _set_auto_parallel_context(**kwargs):
     """
@@ -0,0 +1,289 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore as ms
import mindspore.common.dtype as mstype
from mindspore import context, Tensor, Parameter
from mindspore.communication.management import get_group_size
from mindspore.nn import Cell, Momentum, Norm
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.train import Model
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.initializer import initializer
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.context import ParallelMode

from tests.dataset_mock import MindData


class Dataset(MindData):
    def __init__(self, predict, label, length=3):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        self.index = 0


get_square_sum = C.MultitypeFuncGraph("get_square_sum")
@get_square_sum.register("Tensor")
def _get_square_sum(grad):
    norm = P.ReduceSum(False)(F.square(grad), ())
    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return norm


apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
@apply_global_norm.register("Tensor", "Tensor", "Tensor")
def _apply_global_norm(clip_norm, global_norm, grad):
    grad = grad * clip_norm / global_norm
    return grad


class GlobalNorm(Cell):
    """
    Calculate the global norm value of given tensors
    """
    def __init__(self):
        super(GlobalNorm, self).__init__()
        self.norm = Norm()
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        square_sum = self.hyper_map(get_square_sum, grads)
        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
        return global_norms


class ClipByGlobalNorm(Cell):
    """
    Clip grads by global norm
    """
    def __init__(self, clip_norm=1.0):
        super(ClipByGlobalNorm, self).__init__()
        self.global_norm = GlobalNorm()
        self.clip_norm = Tensor([clip_norm], mstype.float32)
        self.hyper_map = C.HyperMap()

    def construct(self, grads):
        global_norm = self.global_norm(grads)
        cond = P.GreaterEqual()(global_norm, self.clip_norm)
        global_norm = F.select(cond, global_norm, self.clip_norm)
        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
        return grads


cast = P.Cast()
update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")


@update_accu_grads.register("Tensor", "Tensor")
def _update_accu_grads(accu_grad, grad):
    succ = True
    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))


zeroslike = P.ZerosLike()
reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")


@reset_accu_grads.register("Tensor")
def _reset_accu_grads(accu_grad):
    succ = True
    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))


grad_scale = C.MultitypeFuncGraph("grad_scale")
reciprocal = P.Reciprocal()


@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
    return grad * reciprocal(scale)
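update_accu_grads and reset_accu_grads implement the accumulate-then-clear contract on the accu_grads clones. The same contract in plain numpy, for reference:

import numpy as np

accu_grad = np.zeros(4, dtype=np.float32)
for step_grad in (np.ones(4), 2 * np.ones(4)):
    accu_grad += step_grad          # update_accu_grads: accu += cast(grad, fp32)
print(accu_grad)                    # [3. 3. 3. 3.]
accu_grad[:] = 0                    # reset_accu_grads: assign(accu, zeroslike(accu))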
class TrainAccumulateStepsWithLossScaleCell(Cell):
    """
    Encapsulation class of bert network training.

    Append an optimizer to the training network. After that, the construct
    function can be called to create the backward graph. To mimic a higher batch size, gradients are
    accumulated N times before the weight update.

    Args:
        network (Cell): The training network. Note that loss function should have been added.
        optimizer (Optimizer): Optimizer for updating the weights.
        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
                                  batch_size * accumulation_steps. Default: 1.
    """
    def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=4):
        super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.accumulation_steps = accumulation_steps
        self.one = Tensor(np.array([1]).astype(np.int32))
        self.zero = Tensor(np.array([0]).astype(np.int32))
        self.local_step = Parameter(initializer(0, [1], mstype.int32), name="local_step")
        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))

        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.reducer_flag = False
        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
            self.reducer_flag = True
        self.grad_reducer = F.identity
        self.degree = 1
        if self.reducer_flag:
            self.degree = get_group_size()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
        self.overflow_reducer = F.identity
        if self.is_distributed:
            self.overflow_reducer = P.AllReduce()
        self.cast = P.Cast()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.clear_before_grad = P.NPUClearFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.base = Tensor(1, mstype.float32)
        self.less_equal = P.LessEqual()
        self.logical_or = P.LogicalOr()
        self.not_equal = P.NotEqual()
        self.select = P.Select()
        self.reshape = P.Reshape()
        self.hyper_map = C.HyperMap()
        self.loss_scale = None
        self.loss_scaling_manager = scale_update_cell
        if scale_update_cell:
            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))

    @C.add_flags(has_effect=True)
    def construct(self, x, b, sens=None):
        """Defines the computation performed."""
        weights = self.weights
        loss = self.network(x, b)
        if sens is None:
            scaling_sens = self.loss_scale
        else:
            scaling_sens = sens

        # update accumulation parameters
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)
        self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one)
        self.accu_loss = self.select(is_accu_step, self.accu_loss + loss, loss)
        mean_loss = self.accu_loss / self.local_step
        is_accu_step = self.not_equal(self.local_step, self.accumulation_steps)

        # alloc status and clear should be right before gradoperation
        init = self.alloc_status()
        self.clear_before_grad(init)
        grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))

        accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
        mean_loss = F.depend(mean_loss, accu_succ)

        self.get_status(init)
        flag_sum = self.reduce_sum(init, (0,))
        overflow = self.less_equal(self.base, flag_sum)
        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
        accu_overflow = self.select(overflow, self.one, self.zero)
        self.accu_overflow = self.select(is_accu_step, accu_overflow, self.zero)
        is_accu_step = self.reshape(is_accu_step, (()))

        if is_accu_step:
            succ = False
        else:
            # apply grad reducer on grads
            grads = self.grad_reducer(self.accu_grads)
            scaling = scaling_sens * self.degree * self.accumulation_steps
            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
            grads = ClipByGlobalNorm()(grads)
            accu_overflow = self.overflow_reducer(accu_overflow)
            F.control_depend(grads, accu_overflow)
            overflow = self.less_equal(self.base, accu_overflow)
            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
            overflow = F.depend(overflow, accu_succ)
            overflow = self.reshape(overflow, (()))
            if sens is None:
                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
            if overflow:
                succ = False
            else:
                succ = self.optimizer(grads)

        ret = (mean_loss, overflow, scaling_sens)
        return F.depend(ret, succ)
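The cell's local_step bookkeeping decides when the optimizer actually runs: local_step cycles through 1..N and the optimizer fires only when it reaches accumulation_steps. A quick pure-Python trace of the select() logic for accumulation_steps=4:

accumulation_steps = 4
local_step = 0
for it in range(1, 9):
    # select(local_step != accumulation_steps, local_step + 1, 1)
    local_step = local_step + 1 if local_step != accumulation_steps else 1
    is_accu_step = local_step != accumulation_steps
    print(it, local_step, "skip optimizer" if is_accu_step else "apply optimizer")
# the optimizer is applied on iterations 4 and 8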
class Net(Cell):
    def __init__(self, weight, strategy=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy)
        self.weight = Parameter(weight, "w1")
        self.relu = P.ReLU()
        self.reduce_sum = P.ReduceSum(keep_dims=True)

    def construct(self, x, b):
        out = self.mul(x, self.weight)
        out = self.relu(out)
        out = self.reduce_sum(out)
        return out


_x = Tensor(np.ones([2]), dtype=ms.float32)
_b = Tensor(np.ones([16]), dtype=ms.float32)
_w1 = Tensor(np.ones([16]), dtype=ms.float32)


def compile_net(net, grad_accumulation_step):
    context.set_context(save_graphs=True)
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    dataset = Dataset(_x, _b)
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
    net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell,
                                                     accumulation_steps=grad_accumulation_step)
    model = Model(net_wrap)
    model.train(epoch_size, dataset, dataset_sink_mode=False)
    context.reset_auto_parallel_context()


def test_grad_accumulation():
    grad_accumulation_step = 4
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
                                      grad_accumulation_step=grad_accumulation_step)
    strategy = ((2,), (2,))
    net = Net(_w1, strategy)
    compile_net(net, grad_accumulation_step)
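The test trains on mocked data without real collective communication (dataset_sink_mode=False), so it validates graph compilation of the mini-step mirror path rather than numerics. It should be runnable standalone with pytest, e.g. pytest -k test_grad_accumulation against this test file (the diff viewer dropped the file path, so the exact location is not shown here).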