diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.h b/mindspore/ccsrc/frontend/parallel/device_manager.h index 21d776a0677..60432484bc6 100644 --- a/mindspore/ccsrc/frontend/parallel/device_manager.h +++ b/mindspore/ccsrc/frontend/parallel/device_manager.h @@ -98,7 +98,7 @@ class DeviceManager { std::map group_to_rank_; // the key is hash name, value is rank list int64_t global_rank_ = 0; // the real rank in all devices - int64_t stage_num_ = 0; // the stage num + int64_t stage_num_ = 1; // the stage num int64_t stage_id_ = 0; // the stage id of the global_rank_ int64_t rank_index_in_stage_ = 0; // the index of this rank in it's stage int64_t stage_device_num_ = 0; // the device num of one stage diff --git a/mindspore/ccsrc/frontend/parallel/node_check.cc b/mindspore/ccsrc/frontend/parallel/node_check.cc index 1760d150faa..99b5625621b 100644 --- a/mindspore/ccsrc/frontend/parallel/node_check.cc +++ b/mindspore/ccsrc/frontend/parallel/node_check.cc @@ -75,7 +75,8 @@ const std::set BLACK_LIST = {TUPLE_GETITEM, EMBED, CREATINSTANCE, REF_TO_EMBED, - STOP_GRADIENT}; + STOP_GRADIENT, + SEND}; const std::set BATCH_PARALLEL_BLACK_LIST = {PACK, TENSOR_SCATTER_UPDATE, MIN_MAX_UPDATE_PER_LAYER}; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 00497b54acd..201d5d6d384 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -182,6 +182,8 @@ constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLog constexpr char MATMUL[] = "MatMul"; constexpr char GELU[] = "Gelu"; constexpr char TANH[] = "Tanh"; +constexpr char RECEIVE[] = "Receive"; +constexpr char SEND[] = "Send"; constexpr char SHAPE_OP[] = "Shape"; constexpr char SOFTMAX[] = "Softmax"; constexpr char LOG_SOFTMAX[] = "LogSoftmax"; diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc index d0757bf9924..4b6687015ba 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc @@ -26,6 +26,8 @@ #include "frontend/parallel/ops_info/ops_utils.h" #include "frontend/parallel/group_manager.h" #include "frontend/parallel/context.h" +#include "frontend/parallel/step_parallel.h" +#include "frontend/parallel/node_check.h" #include "utils/comm_manager.h" #include "utils/ms_context.h" @@ -37,6 +39,7 @@ static int recv_tag = 0; void PipelineTransformer::Coloring() { auto need_coloring = true; + std::set stage_set; while (need_coloring) { need_coloring = false; for (auto &fg : manager_->func_graphs()) { @@ -52,6 +55,9 @@ void PipelineTransformer::Coloring() { auto user_node = user_pair.first->cast(); user_node->set_stage(graph->stage()); auto user_node_graph = user_node->func_graph(); + if (graph->stage() != -1) { + stage_set.insert(graph->stage()); + } if (graph->stage() == stage_ && user_node_graph->stage() == -1) { user_node_graph->set_stage(graph->stage()); need_coloring = true; @@ -60,6 +66,12 @@ void PipelineTransformer::Coloring() { } } } + MS_EXCEPTION_IF_NULL(g_device_manager); + auto stage_num = g_device_manager->stage_num(); + if (SizeToInt(stage_set.size()) != stage_num) { + MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size(); + } + return; } void PipelineTransformer::BroadCastColoring() { @@ -68,6 +80,96 @@ void PipelineTransformer::BroadCastColoring() { } } +bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + auto prim = GetValueNode(cnode->input(0)); + if (prim == nullptr) { + return false; + } + if (IsInBlackList(prim)) { + MS_LOG(INFO) << "PipelineSplit don't care node:" << prim->name(); + return false; + } + return true; +} + +OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + if (!IsPipelineCareNode(cnode)) { + MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " is not a Pipeline Care Node."; + } + auto shape_list = ExtractShape(cnode); + if (shape_list.empty()) { + MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " failed to extract shape."; + } + auto prim = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() == RESHAPE) { + MS_LOG(EXCEPTION) << "Reshape op can't be a border."; + } + auto attrs = prim->attrs(); + auto op_info = OperatorInstance(prim, attrs, shape_list); + auto &inputs = cnode->inputs(); + std::vector input_value; + for (size_t index = 1; index < inputs.size(); ++index) { + if (inputs[index]->isa()) { + input_value.push_back(GetValueNode(inputs[index])); + } else { + input_value.emplace_back(nullptr); + } + } + op_info->set_input_value(input_value); + op_info->set_outputs_dtype(cnode->Type()); + op_info->set_cnode(cnode); + StrategyPtr strategy = nullptr; + if (!StrategyFound(attrs)) { + strategy = GenerateBatchParallelStrategy(op_info, prim); + } else { + strategy = ExtractStrategy(attrs); + } + MS_EXCEPTION_IF_NULL(strategy); + if (op_info->Init(strategy) == FAILED) { + MS_LOG(EXCEPTION) << "operator: " << prim->name() << " init failed."; + } + return op_info; +} + +std::pair PipelineTransformer::GetOpInfo(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + OperatorInfoPtr op_info = nullptr; + TensorInfo tensor_info; + // op1(stage1)->op2(stage2) + if (IsValueNode(cnode->input(0))) { + op_info = CreateOpInfo(cnode); + MS_EXCEPTION_IF_NULL(op_info); + tensor_info = op_info->outputs_tensor_info()[0]; + } else if (IsValueNode(cnode->input(0))) { + auto graph = GetValueNode(cnode->input(0)); + MS_EXCEPTION_IF_NULL(graph); + auto output = graph->output(); + MS_EXCEPTION_IF_NULL(output); + auto output_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + auto prim = GetValueNode(output_cnode->input(0)); + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() == TUPLE_GETITEM) { + auto index = GetTupleGetItemIndex(output_cnode); + auto pre_getitem_node = output_cnode->input(1)->cast(); + MS_EXCEPTION_IF_NULL(pre_getitem_node); + op_info = CreateOpInfo(pre_getitem_node); + MS_EXCEPTION_IF_NULL(op_info); + tensor_info = op_info->outputs_tensor_info()[index]; + } else { + op_info = CreateOpInfo(output_cnode); + MS_EXCEPTION_IF_NULL(op_info); + tensor_info = op_info->outputs_tensor_info()[0]; + } + } + return std::make_pair(op_info, std::make_shared(tensor_info)); +} + void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) { auto need_coloring = true; while (need_coloring) { @@ -168,26 +270,19 @@ void PipelineTransformer::ParameterColoring() { } } -static std::pair GetShapeType(const AnfNodePtr &node) { - abstract::ShapePtr shape_ptr; +static std::pair GetShapeType(const AnfNodePtr &node, const Shape &shape) { TypePtr type; - std::vector shape; auto cnode = node->cast(); if (cnode != nullptr && IsValueNode(cnode->input(0))) { auto graph = GetValueNode(cnode->input(0)); auto graph_return = graph->get_return(); - shape_ptr = dyn_cast(graph_return->Shape()); type = graph_return->Type(); } else { - shape_ptr = dyn_cast(node->Shape()); type = node->Type(); } - MS_EXCEPTION_IF_NULL(shape_ptr); MS_EXCEPTION_IF_NULL(type); - auto shape_int = shape_ptr->shape(); std::vector element; - std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(element), - [](int elem) { return MakeValue(elem); }); + std::transform(shape.begin(), shape.end(), std::back_inserter(element), [](int elem) { return MakeValue(elem); }); auto shape_list = std::make_shared(element); auto tensor_type = type->cast(); MS_EXCEPTION_IF_NULL(tensor_type); @@ -203,16 +298,20 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank)); OperatorAttrs attrs = {attr_tag, attr_rank}; - auto send_op = CreatOpInstance(attrs, "Send", "send"); + auto send_op = CreatOpInstance(attrs, SEND, "send"); auto send_node = NewValueNode(send_op); auto prim = GetValueNode(send_node); - auto shape_type_pair = GetShapeType(parameter); + auto op_info_pair = GetOpInfo(parameter); + auto tensor_info = op_info_pair.second; + MS_EXCEPTION_IF_NULL(tensor_info); + auto slice_shape = tensor_info->slice_shape(); + auto shape_type_pair = GetShapeType(parameter, slice_shape); prim->set_attr("shape", shape_type_pair.first); prim->set_attr("dtype", shape_type_pair.second); std::vector send_input = {send_node, parameter}; auto send = graph->NewCNode(send_input); OperatorAttrs depend_attrs; - auto depend_op = CreatOpInstance(depend_attrs, "Depend", "depend"); + auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "depend"); std::vector depend_input = {NewValueNode(depend_op), parameter, send}; auto depend = graph->NewCNode(depend_input); SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend}; @@ -223,15 +322,23 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode int index, int user_node_stage, int node_stage) { Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag)); recv_tag += 1; - auto src_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_; + auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_; Attr attr_rank = std::make_pair("src_rank", MakeValue(src_rank)); - auto shape_type_pair = GetShapeType(node); + auto op_info_pair = GetOpInfo(node); + auto tensor_info = op_info_pair.second; + MS_EXCEPTION_IF_NULL(tensor_info); + auto slice_shape = tensor_info->slice_shape(); + auto shape_type_pair = GetShapeType(node, slice_shape); Attr attr_shape = std::make_pair("shape", shape_type_pair.first); Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second); OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype}; - auto recv_op = CreatOpInstance(attrs, "Receive", "recv"); + auto recv_op = CreatOpInstance(attrs, RECEIVE, "recv"); std::vector recv_input = {NewValueNode(recv_op), virtual_param_}; auto recv = graph->NewCNode(recv_input); + auto node_abstract = node->abstract(); + recv->set_abstract(node_abstract); + recv->set_user_data(std::make_shared(tensor_info->tensor_layout())); + recv->set_user_data(op_info_pair.first); manager_->SetEdge(use_node, index, recv); } @@ -317,36 +424,10 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) { void PipelineTransformer::CutGraph() { for (auto &fg : manager_->func_graphs()) { - if (fg == root_) { - ElimRootParameter(); - continue; - } CutBorder(fg); } } -void PipelineTransformer::ElimRootParameter() { - auto output = root_->output()->cast(); - MS_EXCEPTION_IF_NULL(output); - auto prim = GetValueNode(output->input(0)); - if (prim->name() == DEPEND) { - auto opt_cnode = output->input(2)->cast(); - auto prim_make_tuple = GetValueNode(opt_cnode->input(0)); - if (prim_make_tuple->name() == MAKE_TUPLE) { - std::vector new_node_input = {opt_cnode->input(0)}; - for (auto &input : opt_cnode->inputs()) { - if (input->isa()) { - if (IsStageNode(input->cast())) { - new_node_input.push_back(input); - } - } - } - auto new_node = root_->NewCNode(new_node_input); - manager_->Replace(opt_cnode, new_node); - } - } -} - bool PipelineTransformer::IsStageNode(const CNodePtr &node) { for (auto &input : node->inputs()) { if (input->isa()) { @@ -414,11 +495,16 @@ std::pair PipelineTransformer::FindSensNode() { } void PipelineTransformer::CoverSensShape() { + if (IsLastStage()) { + return; + } auto sens_graph_pair = FindSensNode(); auto sens_cnode = sens_graph_pair.first; MS_EXCEPTION_IF_NULL(sens_cnode); OperatorAttrs attrs; auto fill_op = CreatOpInstance(attrs, "Fill", ""); + MS_EXCEPTION_IF_NULL(type_ptr_); + MS_EXCEPTION_IF_NULL(shape_); std::vector fill_input = {NewValueNode(fill_op), NewValueNode(type_ptr_), NewValueNode(MakeValue(shape_->value())), NewValueNode(0)}; auto fill = root_->NewCNode(fill_input); diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h index cdfaf040e94..8397f15b164 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h @@ -19,13 +19,18 @@ #include #include +#include #include "ir/value.h" #include "ir/graph_utils.h" #include "base/base.h" +#include "frontend/parallel/step_parallel.h" #include "frontend/parallel/graph_util/generate_graph.h" namespace mindspore { namespace parallel { +using TensorLayoutPtr = std::shared_ptr; +using TensorInfoPtr = std::shared_ptr; + typedef struct { ValueListPtr shape; TypePtr type; @@ -59,8 +64,10 @@ class PipelineTransformer { void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index, int user_node_stage, int node_stage); void CutBorder(const FuncGraphPtr &graph); - void ElimRootParameter(); bool IsStageNode(const CNodePtr &node); + std::pair GetOpInfo(const AnfNodePtr &node); + OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode); + bool IsPipelineCareNode(const CNodePtr &cnode); std::pair FindSensNode(); FuncGraphManagerPtr manager_; int64_t stage_; diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index edb62f55218..8077b041f47 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -1752,7 +1752,7 @@ void ExtractInformation(const std::vector &all_nodes, bool is_traini SetVirtualDatasetStrategy(cnode); ValueNodePtr prim_anf_node = cnode->input(0)->cast(); PrimitivePtr prim = GetValueNode(prim_anf_node); - if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST) { + if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST || prim->name() == RECEIVE) { continue; } auto attrs = prim->attrs(); @@ -2420,6 +2420,13 @@ std::vector> GetSensLossPairs(const FuncGraphP return sens_loss_pairs; } +bool IsLastStage() { + MS_EXCEPTION_IF_NULL(g_device_manager); + auto stage_num = g_device_manager->stage_num(); + auto stage_id = g_device_manager->stage_id(); + return ((stage_num - 1) == stage_id); +} + void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(root); @@ -2432,7 +2439,9 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vectorisa()) { auto cnode = node->cast(); - if (!IsParallelCareNode(cnode) || !cnode->has_user_data()) { + if (!IsParallelCareNode(cnode) || !cnode->has_user_data() || IsSomePrimitive(cnode, RECEIVE)) { continue; } @@ -2895,7 +2906,7 @@ ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareN for (auto &candidate : candidate_set) { auto candidate_node = candidate.first; auto c = candidate_node->cast(); - if (c == nullptr || !c->has_user_data()) { + if (c == nullptr || !c->has_user_data() || IsSomePrimitive(c, RECEIVE)) { continue; } (void)parameter_user_info.second.second.insert(candidate); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h index c78729d3806..ab4ecdf1019 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.h +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h @@ -131,6 +131,10 @@ std::shared_ptr FindPrevLayout(const AnfNodePtr &node); void ReshapeInit(const std::vector &all_nodes); +StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim); + +bool IsLastStage(); + // Add node for whole graph void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline_split.cc b/mindspore/ccsrc/pipeline/jit/pipeline_split.cc index 594c277a0e9..6f3ce79a19c 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline_split.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline_split.cc @@ -21,6 +21,7 @@ #include "utils/comm_manager.h" #include "frontend/parallel/context.h" #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h" +#include "frontend/parallel/step_parallel.h" namespace mindspore { namespace pipeline { @@ -59,7 +60,7 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num // Only auto_parallel and semi_auto_parallel support PipelineSplit bool PipelineSplit(const ResourcePtr &res) { auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode(); - if (parallel_mode != parallel::SEMI_AUTO_PARALLEL || parallel_mode != parallel::AUTO_PARALLEL) { + if (parallel_mode != parallel::SEMI_AUTO_PARALLEL && parallel_mode != parallel::AUTO_PARALLEL) { MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split."; return true; } @@ -80,6 +81,9 @@ bool PipelineSplit(const ResourcePtr &res) { } auto stage = InferStage(global_rank, stage_num, device_num); auto per_stage_rank_num = device_num / stage_num; + if (parallel::ParallelInit() != parallel::SUCCESS) { + MS_LOG(EXCEPTION) << "parallel init failed."; + } auto transformer = std::make_shared(manager, stage, root, global_rank, per_stage_rank_num); // step1: Do color graph diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py index 684ed98b119..4278fccf7f7 100644 --- a/mindspore/ops/_grad/grad_comm_ops.py +++ b/mindspore/ops/_grad/grad_comm_ops.py @@ -20,9 +20,10 @@ from .. import operations as P from ...common.tensor import RowTensor from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast, - _GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive, + _GetTensorSlice, _MirrorOperator, ReduceOp, ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap) from .grad_base import bprop_getters +from ..operations._inner_ops import Send, Receive @bprop_getters.register(AllReduce) diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 5a45150634a..ae71dd1f3b4 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack, Unique, GatherD, Identity, SequenceMask) from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast, _MirrorOperator, ReduceOp, _VirtualDataset, - _VirtualDiv, _GetTensorSlice, Send, Receive, + _VirtualDiv, _GetTensorSlice, _HostAllGather, _HostReduceScatter) from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary, TensorSummary, HistogramSummary, Print, Assert) diff --git a/mindspore/ops/operations/_inner_ops.py b/mindspore/ops/operations/_inner_ops.py index e091e397452..eb8cb4fdb66 100644 --- a/mindspore/ops/operations/_inner_ops.py +++ b/mindspore/ops/operations/_inner_ops.py @@ -21,6 +21,7 @@ from ... import context from ...common import dtype as mstype from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register from ..operations.math_ops import _infer_shape_reduce +from ...communication.management import get_rank, GlobalComm, _get_group class ExtractImagePatches(PrimitiveWithInfer): @@ -371,6 +372,116 @@ class MatrixDiagPart(PrimitiveWithInfer): return out_shape +class Send(PrimitiveWithInfer): + """ + Send tensors from src_rank to the specified dest_rank. + + Note: + Send and Recveive must be used in combination and have same sr_tag. + Send must be used between servers. + + Args: + sr_tag (int): A required integer identifying the send/recv message tag. The message will + will be received by the Receive op with the same "sr_tag". + dest_rank (int): A required integer identifying the destination rank. + group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group". + + Inputs: + - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + + Examples: + >>> import mindspore.ops.operations as ops + >>> import mindspore.nn as nn + >>> from mindspore.communication import init + >>> from mindspore import Tensor + >>> import numpy as np + >>> + >>> init() + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.depend = ops.Depend() + >>> self.send = ops.Send(st_tag=0, dest_rank=8, group="hccl_world_group") + >>> + >>> def construct(self, x): + >>> out = self.depend(x, self.send(x)) + >>> return out + >>> + >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32)) + >>> net = Net() + >>> output = net(input_) + """ + @prim_attr_register + def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP): + self.rank = get_rank(_get_group(group)) + self.sr_tag = sr_tag + self.group = group + + def infer_shape(self, x_shape): + self.add_prim_attr("shape", x_shape) + return x_shape + + def infer_dtype(self, x_dtype): + self.add_prim_attr("dtype", x_dtype) + return x_dtype + + +class Receive(PrimitiveWithInfer): + """ + receive tensors from src_rank. + + Note: + Send and Recveive must be used in combination and have same sr_tag. + Receive must be used between servers. + + Args: + sr_tag (int): A required integer identifying the send/recv message tag. The message will + will be send by the Send op with the same "sr_tag". + src_rank (int): A required integer identifying the source rank. + shape (list[int]): A required list identifying the shape of the tensor to be received. + dtype (Type): A required Type indentifying the type of the tensor to be received. The supported types: + int8, int16, int32, float16, float32. + group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group". + + Inputs: + - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + + Examples: + >>> import mindspore.ops.operations as ops + >>> import mindspore.nn as nn + >>> from mindspore.communication import init + >>> from mindspore import Tensor + >>> import numpy as np + >>> + >>> init() + >>> class Net(nn.Cell): + >>> def __init__(self): + >>> super(Net, self).__init__() + >>> self.recv = ops.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32, + >>> group="hccl_world_group") + >>> + >>> def construct(self): + >>> out = self.recv() + >>> return out + >>> + >>> net = Net() + >>> output = net() + """ + @prim_attr_register + def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP): + self.rank = get_rank(_get_group(group)) + self.tag = sr_tag + self.shape = shape + self.dtype = dtype + self.group = group + + def infer_shape(self, x_shape=None): + return self.shape + + def infer_dtype(self, x_dtype=None): + return self.dtype + + class MatrixSetDiag(PrimitiveWithInfer): r""" Modifies the batched diagonal part of a batched tensor. diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index c2bfa1f178d..c7e46fdf53c 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -116,117 +116,6 @@ class AllReduce(PrimitiveWithInfer): return x_dtype -class Send(PrimitiveWithInfer): - """ - Send tensors from src_rank to the specified dest_rank. - - Note: - Send and Recveive must be used in combination and have same sr_tag. - Send must be used between servers. - - Args: - sr_tag (int): A required integer identifying the send/recv message tag. The message will - will be received by the Receive op with the same "sr_tag". - dest_rank (int): A required integer identifying the destination rank. - group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group". - - Inputs: - - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. - - Examples: - >>> import mindspore.ops.operations as ops - >>> import mindspore.nn as nn - >>> from mindspore.communication import init - >>> from mindspore import Tensor - >>> import numpy as np - >>> - >>> init() - >>> class Net(nn.Cell): - >>> def __init__(self): - >>> super(Net, self).__init__() - >>> self.depend = ops.Depend() - >>> self.send = ops.Send(st_tag=0, dest_rank=8, group="hccl_world_group") - >>> - >>> def construct(self, x): - >>> out = self.depend(x, self.send(x)) - >>> return out - >>> - >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32)) - >>> net = Net() - >>> output = net(input_) - """ - @prim_attr_register - def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP): - self.rank = get_rank(_get_group(group)) - self.sr_tag = sr_tag - self.group = group - - def infer_shape(self, x_shape): - self.add_prim_attr("shape", x_shape) - return x_shape - - def infer_dtype(self, x_dtype): - self.add_prim_attr("dtype", x_dtype) - return x_dtype - - -class Receive(PrimitiveWithInfer): - """ - receive tensors from src_rank. - - Note: - Send and Recveive must be used in combination and have same sr_tag. - Receive must be used between servers. - - Args: - sr_tag (int): A required integer identifying the send/recv message tag. The message will - will be send by the Send op with the same "sr_tag". - src_rank (int): A required integer identifying the source rank. - shape (list[int]): A required list identifying the shape of the tensor to be received. - dtype (Type): A required Type indentifying the type of the tensor to be received. The supported types: - int8, int16, int32, float16, float32. - group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group". - - Inputs: - - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. - - Examples: - >>> import mindspore.ops.operations as ops - >>> import mindspore.nn as nn - >>> from mindspore.communication import init - >>> from mindspore import Tensor - >>> import numpy as np - >>> - >>> init() - >>> class Net(nn.Cell): - >>> def __init__(self): - >>> super(Net, self).__init__() - >>> self.recv = ops.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32, - >>> group="hccl_world_group") - >>> - >>> def construct(self, x): - >>> out = self.depend(x, self.recv(x)) - >>> return out - >>> - >>> input_ = Tensor(np.ones([2, 8]).astype(np.float32)) - >>> net = Net() - >>> output = net(input_) - """ - @prim_attr_register - def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP): - self.rank = get_rank(_get_group(group)) - self.tag = sr_tag - self.shape = shape - self.dtype = dtype - self.group = group - - def infer_shape(self, x_shape=None): - return self.shape - - def infer_dtype(self, x_dtype=None): - return self.dtype - - class AllGather(PrimitiveWithInfer): """ Gathers tensors from the specified communication group. diff --git a/tests/st/nccl/test_nccl_send_recv_op.py b/tests/st/nccl/test_nccl_send_recv_op.py index 37a25c11051..786c4d6d6b0 100644 --- a/tests/st/nccl/test_nccl_send_recv_op.py +++ b/tests/st/nccl/test_nccl_send_recv_op.py @@ -21,6 +21,7 @@ from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size from mindspore.ops import operations as P +from mindspore.ops.operations._inner_ops import Send, Receive from mindspore.common import dtype as mstype context.set_context(mode=context.GRAPH_MODE, device_target='GPU') @@ -38,7 +39,7 @@ class SendNet(nn.Cell): super(SendNet, self).__init__() self.x = Parameter(initializer(Tensor(x), x.shape), name='x') self.depend = P.Depend() - self.send = P.Send(sr_tag=0, dest_rank=rank+size//2, group=NCCL_WORLD_COMM_GROUP) + self.send = Send(sr_tag=0, dest_rank=rank+size//2, group=NCCL_WORLD_COMM_GROUP) def construct(self): out = self.depend(self.x, self.send(self.x)) @@ -47,8 +48,8 @@ class SendNet(nn.Cell): class RecvNet(nn.Cell): def __init__(self): super(RecvNet, self).__init__() - self.recv = P.Receive(sr_tag=0, src_rank=rank-size//2, shape=[3, 3, 3, 3], dtype=mstype.float32, - group=NCCL_WORLD_COMM_GROUP) + self.recv = Receive(sr_tag=0, src_rank=rank-size//2, shape=[3, 3, 3, 3], dtype=mstype.float32, + group=NCCL_WORLD_COMM_GROUP) def construct(self): out = self.recv() diff --git a/tests/ut/python/parallel/test_pipeline_parallel.py b/tests/ut/python/parallel/test_pipeline_parallel.py deleted file mode 100644 index 3f8147d2e90..00000000000 --- a/tests/ut/python/parallel/test_pipeline_parallel.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import mindspore as ms -import mindspore.nn as nn -from mindspore import Tensor -from mindspore import context -from mindspore.common.api import _executor -from mindspore.ops import composite as C -from mindspore.ops import operations as P -from tests.ut.python.ops.test_math_ops import VirtualLoss - - -grad_all = C.GradOperation(get_all=True) - - -class NetWithLoss(nn.Cell): - def __init__(self, network): - super(NetWithLoss, self).__init__() - self.loss = VirtualLoss() - self.network = network - - def construct(self, x, y): - predict = self.network(x, y) - return self.loss(predict) - - -class GradWrap(nn.Cell): - def __init__(self, network): - super(GradWrap, self).__init__() - self.network = network - - def construct(self, x, y): - return grad_all(self.network)(x, y) - - -class Net(nn.Cell): - def __init__(self, axis=0, stage1=0, stage2=0, strategy1=None, strategy2=None, shape=None, target=""): - super().__init__() - if shape is None: - shape = [64, 64] - self.gatherv2 = P.GatherV2().shard(strategy1).add_prim_attr("primitive_target", target) - self.mul = P.Mul().shard(strategy2) - self.index = Tensor(np.ones(shape), dtype=ms.int32) - self.gatherv2.set_stage(stage1) - self.mul.set_stage(stage2) - self.axis = axis - - def construct(self, x, y): - out = self.gatherv2(x, self.index, self.axis) - out = self.mul(out, y) - return out - - -def test_gatherv2_semi_samestage1(): - context.set_auto_parallel_context(device_num=8, global_rank=0, \ - parallel_mode="semi_auto_parallel", pipeline_stages=2) - strategy1 = ((1, 2), (1, 1)) - strategy2 = ((2, 1, 1), (2, 1, 1)) - net = GradWrap(NetWithLoss(Net(0, 0, 0, strategy1, strategy2))) - net.set_auto_parallel() - - x = Tensor(np.ones([64, 64]), dtype=ms.float32) - y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) - net.set_train() - _executor.compile(net, x, y) - -def test_gatherv2_semi_samestage2(): - context.set_auto_parallel_context(device_num=8, global_rank=5, \ - parallel_mode="semi_auto_parallel", pipeline_stages=2) - strategy1 = ((1, 2), (1, 1)) - strategy2 = ((2, 1, 1), (2, 1, 1)) - net = GradWrap(NetWithLoss(Net(0, 1, 1, strategy1, strategy2))) - net.set_auto_parallel() - - x = Tensor(np.ones([64, 64]), dtype=ms.float32) - y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32) - net.set_train() - _executor.compile(net, x, y) diff --git a/tests/ut/python/parallel/test_pipeline_split.py b/tests/ut/python/parallel/test_pipeline_split.py new file mode 100644 index 00000000000..5866b56c0e3 --- /dev/null +++ b/tests/ut/python/parallel/test_pipeline_split.py @@ -0,0 +1,109 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import mindspore as ms +import mindspore.nn as nn +from mindspore import context +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer +from mindspore.train.model import Model + + +class DatasetLenet(): + def __init__(self, data, label, length=3): + self.data = data + self.label = label + self.index = 1 + self.length = length + + def __iter__(self): + return self + + def __next__(self): + if self.index >= self.length: + raise StopIteration + self.index += 1 + return self.data, self.label + + def reset(self): + self.index = 0 + + def get_dataset_size(self): + return 32 + + def get_repeat_count(self): + return 1 + + def get_batch_size(self): + return 32 + + def create_tuple_iterator(self, num_epochs=1): + return self + + +class MatMulCell(nn.Cell): + def __init__(self, strategy1, strategy2): + super().__init__() + self.param = Parameter(initializer("zeros", [64, 64]), name="param") + self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1") + self.matmul = P.MatMul().shard(strategy1) + self.matmul1 = P.MatMul().shard(strategy2) + + def construct(self, x): + out = self.matmul(x, self.param) + out = self.matmul1(out, self.param1) + return out + + +class Net(nn.Cell): + def __init__(self, strategy1, strategy2): + super().__init__() + self.block = nn.CellList() + for i in range(2): + cell = MatMulCell(strategy1, strategy2) + cell.stage = i + self.block.append(cell) + + def construct(self, x): + for i in range(2): + x = self.block[i](x) + return x + + +class PipelineSplit(nn.Cell): + def __init__(self, strategy1, strategy2): + super().__init__() + self.cell = Net(strategy1, strategy2) + + def construct(self, x, label): + x = self.cell(x) + return x + + +def test_pipeline_split(): + context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + data = Tensor(np.ones([32, 64]), dtype=ms.float32) + label = Tensor(np.ones([64, 64]), dtype=ms.float32) + strategy1 = ((4, 1), (1, 1)) + strategy2 = ((2, 1), (1, 1)) + net = PipelineSplit(strategy1, strategy2) + params = net.cell.block[1].trainable_params() + dataset = DatasetLenet(data, label, 3) + optimizer = nn.Lamb(params, learning_rate=0.01) + model = Model(net, optimizer=optimizer) + model.train(2, dataset, dataset_sink_mode=False)