diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc index 0aabf8fa4a6..9e931cecc5e 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc @@ -106,7 +106,9 @@ OptimizeIRPassLib::OptimizeIRPassLib() { prim::kPrimMirrorMiniStep); mini_step_allgather_replace_ = MakeSubstitution(std::make_shared(), "mini_step_allgather_replace", prim::kPrimMiniStepAllGather); - virtual_add_elim_ = MakeSubstitution(std::make_shared(), "virtual_add", prim::kPrimVirtualAdd); + micro_step_allgather_replace_ = MakeSubstitution(std::make_shared(), + "micro_step_allgather_replace", prim::kPrimMicroStepAllGather); + virtual_add_elim_ = MakeSubstitution(std::make_shared(), "virtual add", prim::kPrimVirtualAdd); check_bprop_eliminate_ = MakeSubstitution(std::make_shared(), "check_bprop_eliminate", prim::kPrimCheckBprop); reset_defer_inline_ = diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index bcae681f1ee..b257eaf22eb 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -60,6 +60,7 @@ class OptimizeIRPassLib { SubstitutionPtr mirror_mini_step_elim_; SubstitutionPtr virtual_add_elim_; SubstitutionPtr mini_step_allgather_replace_; + SubstitutionPtr micro_step_allgather_replace_; // Env Item Eliminate SubstitutionPtr env_get_item_eliminate_; diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index 3ea013d4340..b74d0b32874 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -300,6 +300,39 @@ class MiniStepAllGatherPass : public AnfVisitor { void Visit(const AnfNodePtr &) override {} }; +// {prim::kPrimMicroStepAllGather, X, Z} -> {prim::kPrimAllGather, X} +class MicroStepAllGatherPass : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!IsPrimitiveCNode(node, prim::kPrimMicroStepAllGather) || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + if (inputs.size() < 2) { + return nullptr; + } + auto prim = GetValueNode(node->cast()->input(0)); + MS_EXCEPTION_IF_NULL(prim); + auto attrs = prim->attrs(); + std::string group = attrs[parallel::GROUP]->ToString(); + auto fusion = attrs[parallel::FUSION]; + parallel::Operator op = parallel::CreateAllGatherOp(group); + std::vector node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER); + auto prim_anf_node = node_input[0]->cast(); + prim = GetValueNode(prim_anf_node); + MS_EXCEPTION_IF_NULL(prim); + attrs = prim->attrs(); + attrs[parallel::FUSION] = fusion; + prim->SetAttrs(attrs); + auto func_graph = inputs[1]->func_graph(); + CNodePtr new_node = func_graph->NewCNode(node_input); + return new_node; + } + + void Visit(const AnfNodePtr &) override {} +}; + // Reset defer_inline flag class ResetDeferInline : public AnfVisitor { public: diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h index a1cb6fe15ca..25342eef82a 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.h @@ -36,7 +36,6 @@ constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_INHERENT_TIME = 0 constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_ALLREDUCE_BANDWIDTH = 0.1; constexpr double DEFAULT_COST_MODEL_ALLREDUCE_FUSION_COMPUTATION_TIME_PARAMETER = 0.1; -constexpr char PARAMETER[] = "parameter"; const uint64_t MAX_RECURSIVE_CALL_TIMES = 100; class AllreduceFusion { public: diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc index 5b6afa97231..57f3d3ca7e4 100755 --- a/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/pipeline_split_utils.cc @@ -87,14 +87,12 @@ void SetStridedSliceStrategy(const AnfNodePtr &node) { void InsertVirtualAssignAdd(const std::pair &node_user, const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) { auto cnode = node_user.first->cast(); - if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag() || - ((IsPrimitiveCNode(node_user.first, prim::kPrimSend) || IsPrimitiveCNode(node_user.first, prim::kPrimDepend)) && - ParallelContext::GetInstance()->enable_parallel_optimizer())) { + if (IsPrimitiveCNode(cnode, prim::kPrimReceive) || !cnode->in_forward_flag()) { return; } auto prim = GetCNodePrimitive(cnode); if (prim == nullptr) { - MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAd"; + MS_LOG(WARNING) << cnode->DebugString() << " can not insert _VirtualAssignAdd."; return; } OperatorAttrs attrs; @@ -154,10 +152,12 @@ void HandleReceiveParam(const FuncGraphPtr &root, const std::vector auto node_users = node_users_map[node]; for (auto &temp_user : node_users) { auto temp_node = temp_user.first; + // Micro virtual operator might be inserted after cast if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) { temp_node = node_users_map[temp_node].begin()->first; } - if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) { + if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) || + IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) { auto node_set = node_users_map[temp_node]; for (auto &node_user : node_set) { InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter); @@ -182,10 +182,12 @@ void AddVirtualAssignAdd(const FuncGraphPtr &root) { auto node_users = node_users_map[parameter]; for (auto &temp_user : node_users) { auto temp_node = temp_user.first; + // Micro virtual operator might be inserted after cast if (IsPrimitiveCNode(temp_node, prim::kPrimCast)) { temp_node = node_users_map[temp_node].begin()->first; } - if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep)) { + if (IsPrimitiveCNode(temp_node, prim::kPrimMirrorMicroStep) || + IsPrimitiveCNode(temp_node, prim::kPrimMicroStepAllGather)) { auto node_set = node_users_map[temp_node]; for (auto &node_user : node_set) { InsertVirtualAssignAdd(node_user, root->manager(), accu_parameter); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc index 8e06415863e..c841f23fcbc 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.cc @@ -428,6 +428,26 @@ Operator CreateMiniStepAllGatherOp(const std::string &group) { return op; } +Operator CreateMicroStepAllGatherOp(const std::string &group) { + bool mean_flag = ParallelContext::GetInstance()->gradients_mean(); + + OperatorName operator_name = MICRO_STEP_ALL_GATHER; + ValuePtr attr0_value = MakeValue(group); // group + Attr attr0 = std::make_pair(GROUP, attr0_value); + ValuePtr attr1_value = MakeValue(mean_flag); // mean_flag + Attr attr1 = std::make_pair(MEAN_FLAG, attr1_value); + OperatorAttrs operator_attrs; + operator_attrs.push_back(attr0); + operator_attrs.push_back(attr1); + + OperatorParams operator_param; + OperatorArgs operator_arg = std::make_pair(operator_attrs, operator_param); + + Operator op = std::make_pair(operator_name, operator_arg); + MS_LOG(INFO) << "Create MICRO_STEP_ALL_GATHER success, the group is " << group; + return op; +} + // use for get tensor slice Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { Shape tensor_map = tensor_layout.tensor_map().array(); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index e1beef52045..f503fe372e4 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -299,6 +299,7 @@ Operator CreateReduceScatterOp(const std::string &reduce_op, const std::string & Operator CreateAllGatherOp(const std::string &group); Operator CreateMiniStepAllGatherOp(const std::string &group); void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr ¶m_node); +Operator CreateMicroStepAllGatherOp(const std::string &group); void AddCommOpMeanFlag(const CNodePtr &comm_node); void AddCommOpParamFlag(const CNodePtr &comm_node); Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index a27f7f34197..84ff1fe6459 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -217,6 +217,7 @@ constexpr char LOCAL_STEP[] = "local_step"; constexpr char STRIDED_SLICE[] = "StridedSlice"; constexpr char ALL_GATHER[] = "AllGather"; constexpr char MINI_STEP_ALL_GATHER[] = "_MiniStepAllGather"; +constexpr char MICRO_STEP_ALL_GATHER[] = "_MicroStepAllGather"; constexpr char REDUCE_SCATTER[] = "ReduceScatter"; constexpr char HOST_REDUCE_SCATTER[] = "_HostReduceScatter"; constexpr char EMBEDDING_LOOKUP[] = "EmbeddingLookup"; @@ -383,6 +384,7 @@ constexpr char VIRTUAL_ACCU_GRAD[] = "_VirtualAccuGrad"; constexpr char ACCU_GRAD[] = "accu_grad"; constexpr char PARAMETER_START[] = "parameter_start"; constexpr char PARAM_INDEX[] = "param_index"; +constexpr char PARAMETER[] = "parameter"; // Parallel don't care constexpr char STRING_EQUAL[] = "string_equal"; diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 71103debfca..ff2b05f4715 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -199,7 +199,8 @@ std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operat if (op_name == MIRROR_MINI_STEP_OPERATOR) { op_name = MIRROR_OPERATOR; arg_forward.first.pop_back(); - } else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR) { + } else if (op_name == MINI_STEP_ALL_GATHER || op_name == MIRROR_MICRO_STEP_OPERATOR || + op_name == MICRO_STEP_ALL_GATHER) { MS_LOG(EXCEPTION) << "You should define `accu_grads` when use " << op_name << " parameter:" << weight_name; } } @@ -211,7 +212,7 @@ std::vector CreateMirrorInput(const FuncGraphPtr &root, const Operat std::vector new_node_input; if (op_name == MIRROR_MINI_STEP_OPERATOR || op_name == MINI_STEP_ALL_GATHER || - op_name == MIRROR_MICRO_STEP_OPERATOR) { + op_name == MIRROR_MICRO_STEP_OPERATOR || op_name == MICRO_STEP_ALL_GATHER) { new_node_input = {NewValueNode(pyop_instance), node, grad_accu}; MS_LOG(INFO) << "Insert the grad accumulation node as the mirror op's input"; } else { @@ -1117,6 +1118,15 @@ std::pair FindParameter(const AnfNodePtr &node, const FuncGrap return std::make_pair(nullptr, false); } +// only used for FindCNode +CNodePtr SkipTrivialNodesMoveDown(FuncGraphManagerPtr manager, CNodePtr node) { + MS_EXCEPTION_IF_NULL(node); + while (IsInTrivialNodeList(node) || IsSomePrimitive(node, LOAD)) { + node = manager->node_users()[node].begin()->first->cast(); + } + return node; +} + std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(anode); MS_EXCEPTION_IF_NULL(anode->func_graph()); @@ -1130,6 +1140,9 @@ std::pair FindCNode(const AnfNodePtr &anode, const std::string & if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { continue; } + if (ParallelContext::GetInstance()->enable_parallel_optimizer()) { + use_apply = SkipTrivialNodesMoveDown(manager, use_apply); + } ValueNodePtr prim_anf_node = use_apply->input(0)->cast(); MS_EXCEPTION_IF_NULL(prim_anf_node); PrimitivePtr node_prim = prim_anf_node->value()->cast(); @@ -1202,7 +1215,7 @@ static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &no } // only used for InsertMirrorOps -CNodePtr SkipTrivialNodes(CNodePtr node) { +CNodePtr SkipTrivialNodesMoveUp(CNodePtr node) { MS_EXCEPTION_IF_NULL(node); while (!IsSomePrimitive(node, LOAD)) { if (IsInTrivialNodeList(node) || IsInAllGatherNodeList(node)) { @@ -1287,7 +1300,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons // assume Load is inserted next to parameter // skip Load moving up and insert mirror next to the parameter if (pre_node->cast()) { - CNodePtr load_node = SkipTrivialNodes(node->input(index)->cast()); + CNodePtr load_node = SkipTrivialNodesMoveUp(node->input(index)->cast()); manager->SetEdge(load_node, 1, next_cnode.second); } else { manager->SetEdge(node, static_cast(index), next_cnode.second); @@ -1306,7 +1319,7 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons if (pre_node->cast() && (InsertMirrorBeforeCast(node, index) || is_shared_param)) { // assume Load is inserted next to parameter // skip Load moving up and insert mirror next to the parameter - CNodePtr load_node = SkipTrivialNodes(pre_node->cast()); + CNodePtr load_node = SkipTrivialNodesMoveUp(pre_node->cast()); InsertNode(op, load_node, 1, load_node->input(1), func_graph, mirror_op_name, param_name, root); auto comm_op = load_node->input(1)->cast(); // add fusion flag @@ -1706,6 +1719,8 @@ static void InsertAllGatherOp(const FuncGraphPtr &root, const std::string &group auto param_name = node->cast()->name(); if (op_name == MINI_STEP_ALL_GATHER) { op = CreateMiniStepAllGatherOp(group); + } else if (op_name == MICRO_STEP_ALL_GATHER) { + op = CreateMicroStepAllGatherOp(group); } else { op = CreateAllGatherOp(group); } @@ -1733,9 +1748,12 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & MS_EXCEPTION_IF_NULL(parameter); MS_EXCEPTION_IF_NULL(manager); int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step(); + int32_t split_stage_num = ParallelContext::GetInstance()->pipeline_stage_split_num(); std::string op_name; if (grad_accumulation_step > 1) { op_name = MINI_STEP_ALL_GATHER; + } else if (split_stage_num > 1) { + op_name = MICRO_STEP_ALL_GATHER; } else { op_name = ALL_GATHER; } @@ -1744,7 +1762,7 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & for (auto ¶m_pair : param_sub_set) { auto cnode = param_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (cnode->in_forward_flag()) { + if (cnode->in_forward_flag() && !IsPrimitiveCNode(cnode, prim::kPrimReceive)) { OperatorInfoPtr distribute_operator = cnode->user_data(); if (distribute_operator == nullptr) { MS_LOG(DEBUG) << "Parallel optimizer: " << GetPrimName(cnode) << " 's OperatorInfoPtr is nullptr"; @@ -1759,6 +1777,8 @@ static void ApplyParallelOptOnParam(const FuncGraphPtr &root, const AnfNodePtr & manager->SetEdge(cnode, SizeToLong(param_pair.second), next_cnode.second); MS_LOG(INFO) << "Parallel optimizer is shared between " << parameter->ToString() << " and " << GetPrimName(cnode); + } else { + MS_LOG(ERROR) << "Can not find the shared AllGather with multiple node users."; } } else { // insert allgather operator between shard parameter and cnode @@ -2852,12 +2872,14 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vectorHasPrimalAttr(PIPELINE_PARAM)) { + // insert forward ops + InsertForwardOps(distribute_operator, cnode); + // insert redistribution ops + StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode); + } // insert backward ops if (has_backward) { BackwardCommunication(root, distribute_operator, cnode, sens_loss_pairs); @@ -2873,7 +2895,8 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vectorisa()) { auto cnode = node->cast(); - if (!IsParallelCareNode(cnode) || !cnode->has_user_data() || IsSomePrimitive(cnode, RECEIVE)) { + if (!IsParallelCareNode(cnode) || !cnode->has_user_data() || IsSomePrimitive(cnode, RECEIVE) || + IsSomePrimitive(cnode, SEND)) { continue; } @@ -2922,7 +2945,8 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector("FusedMulAd inline const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared("_MirrorMiniStepOperator"); inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared("_MiniStepAllGather"); +inline const PrimitivePtr kPrimMicroStepAllGather = std::make_shared("_MicroStepAllGather"); inline const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); inline const PrimitivePtr kPrimVirtualAdd = std::make_shared("_VirtualAdd"); inline const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); diff --git a/mindspore/core/utils/parallel_node_check.cc b/mindspore/core/utils/parallel_node_check.cc index 7130cd396ae..39b62321a3a 100644 --- a/mindspore/core/utils/parallel_node_check.cc +++ b/mindspore/core/utils/parallel_node_check.cc @@ -31,7 +31,8 @@ static const std::set PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", "InvertPermutation", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", "stop_gradient", "UpdateState", "Load"}; -static const std::set ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather}; +static const std::set ALLGATHER_NODE_LIST_ = {prim::kPrimAllGather, prim::kPrimMiniStepAllGather, + prim::kPrimMicroStepAllGather}; static const std::set TRIVIAL_NODE_LIST_ = {prim::kPrimCast, prim::kPrimDepend}; // clang-format on diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 0018b5b26c9..1e8dbbd6099 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -16,7 +16,7 @@ from types import FunctionType, MethodType from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean, - _get_parallel_mode) + _get_parallel_mode, _get_enable_parallel_optimizer) from mindspore.context import ParallelMode, get_auto_parallel_context from mindspore._checkparam import Validator as validator from mindspore import ops, nn @@ -538,6 +538,7 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell): super(_TrainPipelineAccuStepCell, self).__init__(network, optimizer, sens) self.accu_grads = self.weights.clone(prefix="accu_grads", init="zeros") self.hyper_map = ops.HyperMap() + self.opt_shard = _get_enable_parallel_optimizer() def construct(self, *inputs): weights = self.weights @@ -545,7 +546,10 @@ class _TrainPipelineAccuStepCell(TrainOneStepCell): sens = ops.Fill()(ops.DType()(loss), ops.Shape()(loss), self.sens) grads = self.grad(self.network, weights)(*inputs, sens) accu_grads = ops.depend(self.accu_grads, grads) - succ = self.optimizer(accu_grads) + if self.opt_shard: + succ = self.optimizer(grads) + else: + succ = self.optimizer(accu_grads) clear = self.hyper_map(_pipeline_clear_grad, accu_grads, grads) loss = ops.depend(loss, succ, clear) return loss diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index 7e59f92e53d..d631ff3fe6a 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -18,13 +18,14 @@ from mindspore import Tensor import mindspore.common.dtype as mstype from mindspore.ops import functional as F from mindspore.communication import get_rank, get_group_size +from mindspore.parallel._utils import _get_enable_parallel_optimizer from .. import operations as P from ...common.tensor import RowTensor from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp, ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap, - _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator) + _VirtualAssignAdd, _VirtualAccuGrad, _MirrorMicroStepOperator, _MicroStepAllGather) from .grad_base import bprop_getters from ..operations._inner_ops import Send, Receive @@ -102,10 +103,14 @@ def get_bprop_receive(self): depend = P.Depend() cast = P.Cast() out_tensor = Tensor(0.0, mstype.float16) + is_opt_shard = _get_enable_parallel_optimizer() def bprop(x, out, dout): send_out = receive_grad(dout) - dx = depend(cast(out_tensor, F.dtype(x)), send_out) + if is_opt_shard: + dx = depend(F.zeros_like(x), send_out) + else: + dx = depend(cast(out_tensor, F.dtype(x)), send_out) return (dx,) return bprop @@ -174,6 +179,7 @@ def get_bprop_mirror_micro_step_operator(self): if "parameter_micro" in self.get_attr_dict(): assign.add_prim_attr("parameter_micro", 0) out_tensor = Tensor(1.0, mstype.float16) + opt_shard = _get_enable_parallel_optimizer() def bprop(x, z, out, dout): real_grad = z @@ -188,6 +194,8 @@ def get_bprop_mirror_micro_step_operator(self): z = F.depend(z, dout) real_grad = all_reduce(z) assign(z, real_grad) + if opt_shard: + return (real_grad, cast(out_tensor, dtype(z))) return (cast(out_tensor, dtype(x)), cast(out_tensor, dtype(z))) return bprop @@ -205,30 +213,17 @@ def get_bprop_broad_cast(self): def get_bprop_all_gather(self): """Generate bprop for AllGather""" fusion = self.get_attr_dict()["fusion"] - if fusion == 0: - reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group) - if self.instance_name: - instance_name = "grad_" + self.instance_name - reduce_scatter.set_prim_instance_name(instance_name) - else: - all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) - if self.instance_name: - instance_name = "grad_" + self.instance_name - all_reduce.set_prim_instance_name(instance_name) - rank = get_rank(self.group) - dev_num = get_group_size(self.group) - split = P.Split(output_num=dev_num) - mean_flag = self.get_attr_dict()["mean_flag"] - scale = 1/self.rank_size + reduce_scatter = ReduceScatter(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) + if self.instance_name: + instance_name = "grad_" + self.instance_name + reduce_scatter.set_prim_instance_name(instance_name) + mean_flag = self.get_attr_dict()["mean_flag"] + scale = 1 / self.rank_size def bprop(x, out, dout): - if fusion == 0: - dx = reduce_scatter(dout) - else: - grad = all_reduce(dout) - dx = split(grad)[rank] - if mean_flag: - dx = F.tensor_mul(dx, scale) + dx = reduce_scatter(dout) + if mean_flag: + dx = F.tensor_mul(dx, scale) return (dx,) return bprop @@ -267,6 +262,35 @@ def get_bprop_mini_step_all_gather(self): return bprop +@bprop_getters.register(_MicroStepAllGather) +def get_bprop_micro_step_all_gather(self): + """Generate bprop for _MicroStepAllGather""" + fusion = self.get_attr_dict()["fusion"] + mean_flag = self.get_attr_dict()["mean_flag"] + scale = 1 / self.rank_size + all_reduce = AllReduce(ReduceOp.SUM, self.group).add_prim_attr("fusion", fusion) + rank = get_rank(self.group) + dev_num = get_group_size(self.group) + split = P.Split(output_num=dev_num) + if self.instance_name: + instance_name = "grad_" + self.instance_name + all_reduce.set_prim_instance_name(instance_name) + cast = P.Cast() + dtype = P.DType() + out_tensor = Tensor(1.0, mstype.float16) + + # z: accu_grad + def bprop(x, z, out, dout): + z = F.depend(z, dout) + real_grad = all_reduce(z) + real_grad = split(real_grad)[rank] + if mean_flag: + real_grad = F.tensor_mul(real_grad, scale) + return (real_grad, cast(out_tensor, dtype(z))) + + return bprop + + @bprop_getters.register(_HostAllGather) def get_bprop_host_all_gather(self): """Generate bprop for _HostAllGather""" diff --git a/mindspore/ops/bprop_mindir/Identity_bprop.mindir b/mindspore/ops/bprop_mindir/Identity_bprop.mindir index a5c3e5eb9ba..2d32735756b 100644 --- a/mindspore/ops/bprop_mindir/Identity_bprop.mindir +++ b/mindspore/ops/bprop_mindir/Identity_bprop.mindir @@ -6,4 +6,4 @@ bprop.10:x* bprop.10:out* bprop.10:dout2 -bprop.10:[CNode]12:2:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 \ No newline at end of file +bprop.10:[CNode]12:2:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a22366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 \ No newline at end of file diff --git a/mindspore/ops/bprop_mindir/ReLU_bprop.mindir b/mindspore/ops/bprop_mindir/ReLU_bprop.mindir index 132c7d78e20..f78ba9ad56f 100644 --- a/mindspore/ops/bprop_mindir/ReLU_bprop.mindir +++ b/mindspore/ops/bprop_mindir/ReLU_bprop.mindir @@ -8,4 +8,4 @@ bprop.2:x* bprop.2:out* bprop.2:dout2 -bprop.2:[CNode]4:3:€14cac93a068aa39edcd5220275a7f3df23c79f939b5f52bbe3321d22bc4706d92366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 \ No newline at end of file +bprop.2:[CNode]4:3:€027af68f320ba40d9fbd0893da424c07f9c3a4ec82e98f9543bff9b5a15547a22366f7bd59ea5ec135e982de03b4f7cab6b61d833d046a6e13f78bdaf2fb2b22248b4695c64d61a01e33ef3ba7144b288e54122debb351a5ac8f55d8914329584c332efad4a51b4773cb78093dd53a4ca850b2dc6cdd5f2ae47106b3fda77bb3522819d4919298eadafe049d3d0f3f1998cec40b35bed9c51c9d28b44ea7726065c0e00bc893ef15ec6199798d6c8c46997153587d375b3240c1195ff2c7278c7e635a08323207b4cb3f73fd8437b4d7ee28a7676a68f005a7749bd19e5ed4eca6c407ad6a3b57190d3702d6a45031d13b97bb6952735edf94fb36f73dbff6cdab258748286fc6d783abacce203dfc79d2fc31e23a427ce1f86e08777a687f71c0606bdbf14ec1b2b2d86ab82b5eb2ac71f1d3d0ba743f7cee45a1d9a0a2d82ac414b8c313aac4f85c6217fbbb7009dd079b2d5548f8b695a470a11cb8cc83e6f5e78f5b3c67f2e7bf339b250c3638aee952e1a073002e2834011401f3827260 \ No newline at end of file diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index e34caa78176..69b4c4f6870 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -37,7 +37,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset, _VirtualOutput, _VirtualDiv, _GetTensorSlice, _VirtualAdd, _VirtualAssignAdd, _VirtualAccuGrad, - _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator) + _HostAllGather, _HostReduceScatter, _MirrorMicroStepOperator, _MicroStepAllGather) from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, TensorSummary, HistogramSummary, Print, Assert) from .control_ops import GeSwitch, Merge diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 043cd9cf8c3..dd5982e6164 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -191,6 +191,7 @@ class AllGather(PrimitiveWithInfer): self.add_prim_attr('rank_size', self.rank_size) self.add_prim_attr('group', _get_group(group)) self.add_prim_attr('fusion', 0) + self.add_prim_attr('mean_flag', False) def infer_shape(self, x_shape): validator.check_positive_int(len(x_shape), "x shape", self.name) @@ -239,6 +240,36 @@ class _MiniStepAllGather(PrimitiveWithInfer): return x_dtype +class _MicroStepAllGather(PrimitiveWithInfer): + """ + Auto parallel virtual operator. Do nothing in forward, do reducescatter in backward in mini-step. It is only for + internal use of parallel modules and cannot be called by users. + + Args: + group (str): The communication group to work on. Default: None. + """ + @prim_attr_register + def __init__(self, group=GlobalComm.WORLD_COMM_GROUP, mean_flag=None): + validator.check_value_type('group', _get_group(group), (str,), self.name) + self.rank = get_rank(_get_group(group)) + self.rank_size = get_group_size(_get_group(group)) + validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name) + self.add_prim_attr('rank_size', self.rank_size) + self.add_prim_attr('group', _get_group(group)) + self.add_prim_attr('fusion', 1) + self.mean_flag = mean_flag + + def infer_shape(self, x_shape, z_shape): + validator.check_positive_int(len(x_shape), "x shape", self.name) + if x_shape[0] > 0: + x_shape[0] = x_shape[0] * self.rank_size + return x_shape + + def infer_dtype(self, x_dtype, z_dtype): + validator.check_tensor_dtype_valid('x', x_dtype, target_dtypes, self.name) + return x_dtype + + class _HostAllGather(PrimitiveWithInfer): """ Gathers tensors from the specified communication group on host. diff --git a/mindspore/parallel/_utils.py b/mindspore/parallel/_utils.py index a959314ca55..23c3112dae2 100644 --- a/mindspore/parallel/_utils.py +++ b/mindspore/parallel/_utils.py @@ -160,6 +160,11 @@ def _get_parameter_broadcast(): return parameter_broadcast +def _get_enable_parallel_optimizer(): + """Get if using parallel optimizer.""" + return auto_parallel_context().get_enable_parallel_optimizer() + + def _device_number_check(parallel_mode, device_number): """ Check device num. diff --git a/tests/ut/python/parallel/test_pipeline_split.py b/tests/ut/python/parallel/test_pipeline_split.py index 6d4d1bf1bab..9204d04286c 100644 --- a/tests/ut/python/parallel/test_pipeline_split.py +++ b/tests/ut/python/parallel/test_pipeline_split.py @@ -173,3 +173,67 @@ def test_pipeline_split_shared_parameter_stage1(): optimizer = nn.Lamb(params, learning_rate=0.01) model = Model(net, optimizer=optimizer) model.train(2, dataset, dataset_sink_mode=False) + +def test_pipeline_split_stage0_opt_shard(): + context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + data = Tensor(np.ones([32, 64]), dtype=ms.float32) + label = Tensor(np.ones([64, 64]), dtype=ms.float32) + strategy1 = ((4, 1), (1, 1)) + strategy2 = ((2, 1), (1, 1)) + net = PipelineCell(PipelineSplit(strategy1, strategy2), 4) + params = net.network.cell.block[0].trainable_params() + dataset = DatasetLenet(data, label, 3) + optimizer = nn.Lamb(params, learning_rate=0.01) + model = Model(net, optimizer=optimizer) + model.train(2, dataset, dataset_sink_mode=False) + for _, param in model._train_network.parameters_and_names(): + assert param.name != "cell.block.1.param" + assert param.name != "cell.block.1.param1" + +def test_pipeline_split_stage1_opt_shard(): + context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + data = Tensor(np.ones([32, 64]), dtype=ms.float32) + label = Tensor(np.ones([64, 64]), dtype=ms.float32) + strategy1 = ((4, 1), (1, 1)) + strategy2 = ((2, 1), (1, 1)) + net = PipelineCell(PipelineSplit(strategy1, strategy2), 4) + params = net.network.cell.block[1].trainable_params() + dataset = DatasetLenet(data, label, 3) + optimizer = nn.Lamb(params, learning_rate=0.01) + model = Model(net, optimizer=optimizer) + model.train(2, dataset, dataset_sink_mode=False) + for _, param in model._train_network.parameters_and_names(): + assert param.name != "cell.block.0.param" + assert param.name != "cell.block.0.param1" + + +def test_pipeline_split_shared_parameter_stage0_opt_shard(): + context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + data = Tensor(np.ones([32, 64]), dtype=ms.float32) + label = Tensor(np.ones([64, 64]), dtype=ms.float32) + strategy1 = ((4, 1), (1, 1)) + strategy2 = ((2, 1), (1, 1)) + net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4) + params = net.network.cell.block[0].trainable_params() + dataset = DatasetLenet(data, label, 3) + optimizer = nn.Lamb(params, learning_rate=0.01) + model = Model(net, optimizer=optimizer) + model.train(2, dataset, dataset_sink_mode=False) + + +def test_pipeline_split_shared_parameter_stage1_opt_shard(): + context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + data = Tensor(np.ones([32, 64]), dtype=ms.float32) + label = Tensor(np.ones([64, 64]), dtype=ms.float32) + strategy1 = ((4, 1), (1, 1)) + strategy2 = ((2, 1), (1, 1)) + net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4) + params = net.network.cell.block[1].trainable_params() + dataset = DatasetLenet(data, label, 3) + optimizer = nn.Lamb(params, learning_rate=0.01) + model = Model(net, optimizer=optimizer) + model.train(2, dataset, dataset_sink_mode=False)