pipeline_split adapt parallel

lichenever 2020-12-01 15:23:26 +08:00
parent cffe2c94fe
commit 78e131cf15
15 changed files with 393 additions and 258 deletions

View File

@@ -98,7 +98,7 @@ class DeviceManager {
std::map<std::string, std::string> group_to_rank_; // the key is hash name, value is rank list
int64_t global_rank_ = 0; // the real rank in all devices
int64_t stage_num_ = 0; // the stage num
int64_t stage_num_ = 1; // the stage num
int64_t stage_id_ = 0; // the stage id of the global_rank_
int64_t rank_index_in_stage_ = 0; // the index of this rank in its stage
int64_t stage_device_num_ = 0; // the device num of one stage

View File

@@ -75,7 +75,8 @@ const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
EMBED,
CREATINSTANCE,
REF_TO_EMBED,
STOP_GRADIENT};
STOP_GRADIENT,
SEND};
const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {PACK, TENSOR_SCATTER_UPDATE, MIN_MAX_UPDATE_PER_LAYER};

View File

@@ -182,6 +182,8 @@ constexpr char SIGMOID_CROSS_ENTROPY_WITH_LOGITS[] = "SigmoidCrossEntropyWithLog
constexpr char MATMUL[] = "MatMul";
constexpr char GELU[] = "Gelu";
constexpr char TANH[] = "Tanh";
constexpr char RECEIVE[] = "Receive";
constexpr char SEND[] = "Send";
constexpr char SHAPE_OP[] = "Shape";
constexpr char SOFTMAX[] = "Softmax";
constexpr char LOG_SOFTMAX[] = "LogSoftmax";

View File

@@ -26,6 +26,8 @@
#include "frontend/parallel/ops_info/ops_utils.h"
#include "frontend/parallel/group_manager.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/node_check.h"
#include "utils/comm_manager.h"
#include "utils/ms_context.h"
@@ -37,6 +39,7 @@ static int recv_tag = 0;
void PipelineTransformer::Coloring() {
auto need_coloring = true;
std::set<int64_t> stage_set;
while (need_coloring) {
need_coloring = false;
for (auto &fg : manager_->func_graphs()) {
@@ -52,6 +55,9 @@
auto user_node = user_pair.first->cast<CNodePtr>();
user_node->set_stage(graph->stage());
auto user_node_graph = user_node->func_graph();
if (graph->stage() != -1) {
stage_set.insert(graph->stage());
}
if (graph->stage() == stage_ && user_node_graph->stage() == -1) {
user_node_graph->set_stage(graph->stage());
need_coloring = true;
@@ -60,6 +66,12 @@ }
}
}
}
MS_EXCEPTION_IF_NULL(g_device_manager);
auto stage_num = g_device_manager->stage_num();
if (SizeToInt(stage_set.size()) != stage_num) {
MS_LOG(EXCEPTION) << "Stage num is " << stage_num << " is not equal to stage used: " << stage_set.size();
}
return;
}
void PipelineTransformer::BroadCastColoring() {
@@ -68,6 +80,96 @@ void PipelineTransformer::BroadCastColoring() {
}
}
bool PipelineTransformer::IsPipelineCareNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
if (prim == nullptr) {
return false;
}
if (IsInBlackList(prim)) {
MS_LOG(INFO) << "PipelineSplit don't care node:" << prim->name();
return false;
}
return true;
}
OperatorInfoPtr PipelineTransformer::CreateOpInfo(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (!IsPipelineCareNode(cnode)) {
MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " is not a Pipeline Care Node.";
}
auto shape_list = ExtractShape(cnode);
if (shape_list.empty()) {
MS_LOG(EXCEPTION) << "Node: " << cnode->ToString() << " failed to extract shape.";
}
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == RESHAPE) {
MS_LOG(EXCEPTION) << "Reshape op can't be a border.";
}
auto attrs = prim->attrs();
auto op_info = OperatorInstance(prim, attrs, shape_list);
auto &inputs = cnode->inputs();
std::vector<ValuePtr> input_value;
for (size_t index = 1; index < inputs.size(); ++index) {
if (inputs[index]->isa<ValueNode>()) {
input_value.push_back(GetValueNode(inputs[index]));
} else {
input_value.emplace_back(nullptr);
}
}
op_info->set_input_value(input_value);
op_info->set_outputs_dtype(cnode->Type());
op_info->set_cnode(cnode);
StrategyPtr strategy = nullptr;
if (!StrategyFound(attrs)) {
strategy = GenerateBatchParallelStrategy(op_info, prim);
} else {
strategy = ExtractStrategy(attrs);
}
MS_EXCEPTION_IF_NULL(strategy);
if (op_info->Init(strategy) == FAILED) {
MS_LOG(EXCEPTION) << "operator: " << prim->name() << " init failed.";
}
return op_info;
}
std::pair<OperatorInfoPtr, TensorInfoPtr> PipelineTransformer::GetOpInfo(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
OperatorInfoPtr op_info = nullptr;
TensorInfo tensor_info;
// op1(stage1)->op2(stage2)
if (IsValueNode<Primitive>(cnode->input(0))) {
op_info = CreateOpInfo(cnode);
MS_EXCEPTION_IF_NULL(op_info);
tensor_info = op_info->outputs_tensor_info()[0];
} else if (IsValueNode<FuncGraph>(cnode->input(0))) {
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(graph);
auto output = graph->output();
MS_EXCEPTION_IF_NULL(output);
auto output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
auto prim = GetValueNode<PrimitivePtr>(output_cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == TUPLE_GETITEM) {
auto index = GetTupleGetItemIndex(output_cnode);
auto pre_getitem_node = output_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_getitem_node);
op_info = CreateOpInfo(pre_getitem_node);
MS_EXCEPTION_IF_NULL(op_info);
tensor_info = op_info->outputs_tensor_info()[index];
} else {
op_info = CreateOpInfo(output_cnode);
MS_EXCEPTION_IF_NULL(op_info);
tensor_info = op_info->outputs_tensor_info()[0];
}
}
return std::make_pair(op_info, std::make_shared<TensorInfo>(tensor_info));
}
void PipelineTransformer::DoBroadCast(const FuncGraphPtr &func) {
auto need_coloring = true;
while (need_coloring) {
@@ -168,26 +270,19 @@ void PipelineTransformer::ParameterColoring() {
}
}
static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node) {
abstract::ShapePtr shape_ptr;
static std::pair<ValueListPtr, TypePtr> GetShapeType(const AnfNodePtr &node, const Shape &shape) {
TypePtr type;
std::vector<int64_t> shape;
auto cnode = node->cast<CNodePtr>();
if (cnode != nullptr && IsValueNode<FuncGraph>(cnode->input(0))) {
auto graph = GetValueNode<FuncGraphPtr>(cnode->input(0));
auto graph_return = graph->get_return();
shape_ptr = dyn_cast<abstract::Shape>(graph_return->Shape());
type = graph_return->Type();
} else {
shape_ptr = dyn_cast<abstract::Shape>(node->Shape());
type = node->Type();
}
MS_EXCEPTION_IF_NULL(shape_ptr);
MS_EXCEPTION_IF_NULL(type);
auto shape_int = shape_ptr->shape();
std::vector<ValuePtr> element;
std::transform(shape_int.begin(), shape_int.end(), std::back_inserter(element),
[](int elem) { return MakeValue(elem); });
std::transform(shape.begin(), shape.end(), std::back_inserter(element), [](int elem) { return MakeValue(elem); });
auto shape_list = std::make_shared<ValueList>(element);
auto tensor_type = type->cast<mindspore::TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
@@ -203,16 +298,20 @@ SendAttr PipelineTransformer::InsertSend(const FuncGraphPtr &graph, const AnfNod
auto dest_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
Attr attr_rank = std::make_pair("dest_rank", MakeValue(dest_rank));
OperatorAttrs attrs = {attr_tag, attr_rank};
auto send_op = CreatOpInstance(attrs, "Send", "send");
auto send_op = CreatOpInstance(attrs, SEND, "send");
auto send_node = NewValueNode(send_op);
auto prim = GetValueNode<PrimitivePtr>(send_node);
auto shape_type_pair = GetShapeType(parameter);
auto op_info_pair = GetOpInfo(parameter);
auto tensor_info = op_info_pair.second;
MS_EXCEPTION_IF_NULL(tensor_info);
auto slice_shape = tensor_info->slice_shape();
auto shape_type_pair = GetShapeType(parameter, slice_shape);
prim->set_attr("shape", shape_type_pair.first);
prim->set_attr("dtype", shape_type_pair.second);
std::vector<AnfNodePtr> send_input = {send_node, parameter};
auto send = graph->NewCNode(send_input);
OperatorAttrs depend_attrs;
auto depend_op = CreatOpInstance(depend_attrs, "Depend", "depend");
auto depend_op = CreatOpInstance(depend_attrs, DEPEND, "depend");
std::vector<AnfNodePtr> depend_input = {NewValueNode(depend_op), parameter, send};
auto depend = graph->NewCNode(depend_input);
SendAttr send_out = {shape_type_pair.first, shape_type_pair.second, depend};
@@ -223,15 +322,23 @@ void PipelineTransformer::InsertReceive(const FuncGraphPtr &graph, const AnfNode
int index, int user_node_stage, int node_stage) {
Attr attr_tag = std::make_pair("sr_tag", MakeValue(recv_tag));
recv_tag += 1;
auto src_rank = global_rank_ + (user_node_stage - node_stage) * per_stage_rank_num_;
auto src_rank = global_rank_ - (user_node_stage - node_stage) * per_stage_rank_num_;
Attr attr_rank = std::make_pair("src_rank", MakeValue(src_rank));
auto shape_type_pair = GetShapeType(node);
auto op_info_pair = GetOpInfo(node);
auto tensor_info = op_info_pair.second;
MS_EXCEPTION_IF_NULL(tensor_info);
auto slice_shape = tensor_info->slice_shape();
auto shape_type_pair = GetShapeType(node, slice_shape);
Attr attr_shape = std::make_pair("shape", shape_type_pair.first);
Attr attr_dtype = std::make_pair("dtype", shape_type_pair.second);
OperatorAttrs attrs = {attr_tag, attr_rank, attr_shape, attr_dtype};
auto recv_op = CreatOpInstance(attrs, "Receive", "recv");
auto recv_op = CreatOpInstance(attrs, RECEIVE, "recv");
std::vector<AnfNodePtr> recv_input = {NewValueNode(recv_op), virtual_param_};
auto recv = graph->NewCNode(recv_input);
auto node_abstract = node->abstract();
recv->set_abstract(node_abstract);
recv->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_info->tensor_layout()));
recv->set_user_data<OperatorInfo>(op_info_pair.first);
manager_->SetEdge(use_node, index, recv);
}
@@ -317,36 +424,10 @@ void PipelineTransformer::CutBorder(const FuncGraphPtr &graph) {
void PipelineTransformer::CutGraph() {
for (auto &fg : manager_->func_graphs()) {
if (fg == root_) {
ElimRootParameter();
continue;
}
CutBorder(fg);
}
}
void PipelineTransformer::ElimRootParameter() {
auto output = root_->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output);
auto prim = GetValueNode<PrimitivePtr>(output->input(0));
if (prim->name() == DEPEND) {
auto opt_cnode = output->input(2)->cast<CNodePtr>();
auto prim_make_tuple = GetValueNode<PrimitivePtr>(opt_cnode->input(0));
if (prim_make_tuple->name() == MAKE_TUPLE) {
std::vector<AnfNodePtr> new_node_input = {opt_cnode->input(0)};
for (auto &input : opt_cnode->inputs()) {
if (input->isa<CNode>()) {
if (IsStageNode(input->cast<CNodePtr>())) {
new_node_input.push_back(input);
}
}
}
auto new_node = root_->NewCNode(new_node_input);
manager_->Replace(opt_cnode, new_node);
}
}
}
bool PipelineTransformer::IsStageNode(const CNodePtr &node) {
for (auto &input : node->inputs()) {
if (input->isa<Parameter>()) {
@@ -414,11 +495,16 @@ std::pair<CNodePtr, FuncGraphPtr> PipelineTransformer::FindSensNode() {
}
void PipelineTransformer::CoverSensShape() {
if (IsLastStage()) {
return;
}
auto sens_graph_pair = FindSensNode();
auto sens_cnode = sens_graph_pair.first;
MS_EXCEPTION_IF_NULL(sens_cnode);
OperatorAttrs attrs;
auto fill_op = CreatOpInstance(attrs, "Fill", "");
MS_EXCEPTION_IF_NULL(type_ptr_);
MS_EXCEPTION_IF_NULL(shape_);
std::vector<AnfNodePtr> fill_input = {NewValueNode(fill_op), NewValueNode(type_ptr_),
NewValueNode(MakeValue(shape_->value())), NewValueNode(0)};
auto fill = root_->NewCNode(fill_input);
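For orientation, here is a minimal Python sketch (not part of the commit) of the peer-rank arithmetic that InsertSend and InsertReceive use above; the helper names and the 2-stage, 8-device numbers are illustrative assumptions.
def send_dest_rank(global_rank, node_stage, user_node_stage, per_stage_rank_num):
    # Mirrors InsertSend: the peer sits (user_node_stage - node_stage) whole
    # stages ahead of the sender.
    return global_rank + (user_node_stage - node_stage) * per_stage_rank_num
def recv_src_rank(global_rank, node_stage, user_node_stage, per_stage_rank_num):
    # Mirrors the corrected InsertReceive: the peer sits the same number of
    # stages behind the receiver, hence the sign flip in this commit.
    return global_rank - (user_node_stage - node_stage) * per_stage_rank_num
per_stage = 8 // 2  # device_num / stage_num for a 2-stage, 8-device job
assert send_dest_rank(1, 0, 1, per_stage) == 5   # rank 1 (stage 0) sends to rank 5
assert recv_src_rank(5, 0, 1, per_stage) == 1    # rank 5 (stage 1) receives from rank 1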

View File

@@ -19,13 +19,18 @@
#include <utility>
#include <string>
#include <memory>
#include "ir/value.h"
#include "ir/graph_utils.h"
#include "base/base.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/graph_util/generate_graph.h"
namespace mindspore {
namespace parallel {
using TensorLayoutPtr = std::shared_ptr<TensorLayout>;
using TensorInfoPtr = std::shared_ptr<TensorInfo>;
typedef struct {
ValueListPtr shape;
TypePtr type;
@@ -59,8 +64,10 @@ class PipelineTransformer {
void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, int index,
int user_node_stage, int node_stage);
void CutBorder(const FuncGraphPtr &graph);
void ElimRootParameter();
bool IsStageNode(const CNodePtr &node);
std::pair<OperatorInfoPtr, TensorInfoPtr> GetOpInfo(const AnfNodePtr &node);
OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode);
bool IsPipelineCareNode(const CNodePtr &cnode);
std::pair<CNodePtr, FuncGraphPtr> FindSensNode();
FuncGraphManagerPtr manager_;
int64_t stage_;

View File

@@ -1752,7 +1752,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
SetVirtualDatasetStrategy(cnode);
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST) {
if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST || prim->name() == RECEIVE) {
continue;
}
auto attrs = prim->attrs();
@@ -2420,6 +2420,13 @@ std::vector<std::pair<CNodePtr, LossNodeInfo>> GetSensLossPairs(const FuncGraphP
return sens_loss_pairs;
}
bool IsLastStage() {
MS_EXCEPTION_IF_NULL(g_device_manager);
auto stage_num = g_device_manager->stage_num();
auto stage_id = g_device_manager->stage_id();
return ((stage_num - 1) == stage_id);
}
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager) {
MS_EXCEPTION_IF_NULL(root);
@@ -2432,7 +2439,9 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
for (auto &pair : sens_loss_pairs) {
// If the shape of the grad-sens tensor is not [] or [1], use get tensor slice to handle it.
// If the type of the sens node is not Tensor, it is unsupported now; do nothing by default.
StepSplitSens(pair);
if (IsLastStage()) {
StepSplitSens(pair);
}
}
for (auto &node : all_nodes) {
@@ -2448,13 +2457,15 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
MS_EXCEPTION_IF_NULL(distribute_operator);
// insert forward ops
InsertForwardOps(distribute_operator, cnode);
if (!IsSomePrimitive(cnode, RECEIVE)) {
InsertForwardOps(distribute_operator, cnode);
}
// insert redistribution ops
StepRedistribution(cnode, distribute_operator, cnode, tensor_redistribution, cnode);
// insert backward ops
if (has_backward) {
if (has_backward && !IsSomePrimitive(cnode, RECEIVE)) {
BackwardCommunication(distribute_operator, cnode, sens_loss_pairs);
}
@@ -2468,7 +2479,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>() || IsSomePrimitive(cnode, RECEIVE)) {
continue;
}
@@ -2895,7 +2906,7 @@ ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareN
for (auto &candidate : candidate_set) {
auto candidate_node = candidate.first;
auto c = candidate_node->cast<CNodePtr>();
if (c == nullptr || !c->has_user_data<OperatorInfo>()) {
if (c == nullptr || !c->has_user_data<OperatorInfo>() || IsSomePrimitive(c, RECEIVE)) {
continue;
}
(void)parameter_user_info.second.second.insert(candidate);
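As a quick reference, the last-stage gate added in this file can be read as the small Python sketch below (illustrative only; stage_id and stage_num are assumed to come from the device manager, as IsLastStage() does).
def is_last_stage(stage_id, stage_num):
    # Mirrors IsLastStage(): the final pipeline stage is the one that evaluates
    # the loss, so only it runs StepSplitSens and the loss-side backward
    # communication guarded above.
    return stage_id == stage_num - 1
assert is_last_stage(stage_id=1, stage_num=2)
assert not is_last_stage(stage_id=0, stage_num=2)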

View File

@@ -131,6 +131,10 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node);
void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
StrategyPtr GenerateBatchParallelStrategy(const OperatorInfoPtr operator_, const PrimitivePtr prim);
bool IsLastStage();
// Add node for whole graph
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager);

View File

@@ -21,6 +21,7 @@
#include "utils/comm_manager.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace pipeline {
@@ -59,7 +60,7 @@ static int64_t InferStage(int64_t rank_id, int64_t stage_num, int64_t device_num
// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
if (parallel_mode != parallel::SEMI_AUTO_PARALLEL || parallel_mode != parallel::AUTO_PARALLEL) {
if (parallel_mode != parallel::SEMI_AUTO_PARALLEL && parallel_mode != parallel::AUTO_PARALLEL) {
MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
return true;
}
@@ -80,6 +81,9 @@ bool PipelineSplit(const ResourcePtr &res) {
}
auto stage = InferStage(global_rank, stage_num, device_num);
auto per_stage_rank_num = device_num / stage_num;
if (parallel::ParallelInit() != parallel::SUCCESS) {
MS_LOG(EXCEPTION) << "parallel init failed.";
}
auto transformer =
std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
// step1: Do color graph
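A hedged sketch of the stage layout this pass works with; InferStage's body is not shown in this hunk, so the contiguous rank-to-stage mapping below is an assumption consistent with per_stage_rank_num = device_num / stage_num.
def infer_stage(global_rank, stage_num, device_num):
    # Assumed contiguous layout: ranks [0, per_stage) form stage 0,
    # [per_stage, 2 * per_stage) form stage 1, and so on.
    per_stage_rank_num = device_num // stage_num
    return global_rank // per_stage_rank_num
# With device_num=8 and pipeline_stages=2 (the configuration used in the tests
# later in this commit), rank 4 lands in stage 1 and each stage spans 4 devices.
assert infer_stage(4, 2, 8) == 1
assert infer_stage(3, 2, 8) == 0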

View File

@@ -20,9 +20,10 @@ from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive,
_GetTensorSlice, _MirrorOperator, ReduceOp,
ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
from .grad_base import bprop_getters
from ..operations._inner_ops import Send, Receive
@bprop_getters.register(AllReduce)

View File

@@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Unique, GatherD, Identity, SequenceMask)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice, Send, Receive,
_VirtualDiv, _GetTensorSlice,
_HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)

View File

@@ -21,6 +21,7 @@ from ... import context
from ...common import dtype as mstype
from ..primitive import PrimitiveWithCheck, PrimitiveWithInfer, prim_attr_register
from ..operations.math_ops import _infer_shape_reduce
from ...communication.management import get_rank, GlobalComm, _get_group
class ExtractImagePatches(PrimitiveWithInfer):
@@ -371,6 +372,116 @@ class MatrixDiagPart(PrimitiveWithInfer):
return out_shape
class Send(PrimitiveWithInfer):
"""
Send tensors to the specified dest_rank.
Note:
Send and Receive must be used in combination and have the same sr_tag.
Send must be used between servers.
Args:
sr_tag (int): A required integer identifying the send/recv message tag. The message will
be received by the Receive op with the same "sr_tag".
dest_rank (int): A required integer identifying the destination rank.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops.operations as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.depend = ops.Depend()
>>> self.send = ops.Send(sr_tag=0, dest_rank=8, group="hccl_world_group")
>>>
>>> def construct(self, x):
>>> out = self.depend(x, self.send(x))
>>> return out
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.sr_tag = sr_tag
self.group = group
def infer_shape(self, x_shape):
self.add_prim_attr("shape", x_shape)
return x_shape
def infer_dtype(self, x_dtype):
self.add_prim_attr("dtype", x_dtype)
return x_dtype
class Receive(PrimitiveWithInfer):
"""
Receive tensors from src_rank.
Note:
Send and Receive must be used in combination and have the same sr_tag.
Receive must be used between servers.
Args:
sr_tag (int): A required integer identifying the send/recv message tag. The message will
be sent by the Send op with the same "sr_tag".
src_rank (int): A required integer identifying the source rank.
shape (list[int]): A required list identifying the shape of the tensor to be received.
dtype (Type): A required Type identifying the type of the tensor to be received. The supported types:
int8, int16, int32, float16, float32.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops.operations as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.recv = ops.Receive(sr_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
>>> group="hccl_world_group")
>>>
>>> def construct(self):
>>> out = self.recv()
>>> return out
>>>
>>> net = Net()
>>> output = net()
"""
@prim_attr_register
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.tag = sr_tag
self.shape = shape
self.dtype = dtype
self.group = group
def infer_shape(self, x_shape=None):
return self.shape
def infer_dtype(self, x_dtype=None):
return self.dtype
class MatrixSetDiag(PrimitiveWithInfer):
r"""
Modifies the batched diagonal part of a batched tensor.
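To show how the two primitives above pair up, here is a hedged usage sketch (not part of the commit). It assumes an already-launched multi-device job in which rank 0 sends to rank 4, reuses the import path introduced by this commit, and relies on the default world communication group.
import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as ops
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.communication.management import init, get_rank
from mindspore.ops.operations._inner_ops import Send, Receive

init()

class SendSide(nn.Cell):
    def __init__(self, dest_rank):
        super(SendSide, self).__init__()
        self.depend = ops.Depend()
        # sr_tag must match on both ends; group defaults to the world group.
        self.send = Send(sr_tag=0, dest_rank=dest_rank)

    def construct(self, x):
        # Depend keeps the Send node alive in the graph and forwards x unchanged.
        return self.depend(x, self.send(x))

class RecvSide(nn.Cell):
    def __init__(self, src_rank):
        super(RecvSide, self).__init__()
        # shape and dtype must describe exactly what the sender transmits.
        self.recv = Receive(sr_tag=0, src_rank=src_rank, shape=[2, 8], dtype=mstype.float32)

    def construct(self):
        return self.recv()

if get_rank() == 0:
    out = SendSide(dest_rank=4)(Tensor(np.ones([2, 8]).astype(np.float32)))
else:
    out = RecvSide(src_rank=0)()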

View File

@@ -116,117 +116,6 @@ class AllReduce(PrimitiveWithInfer):
return x_dtype
class Send(PrimitiveWithInfer):
"""
Send tensors from src_rank to the specified dest_rank.
Note:
Send and Recveive must be used in combination and have same sr_tag.
Send must be used between servers.
Args:
sr_tag (int): A required integer identifying the send/recv message tag. The message will
will be received by the Receive op with the same "sr_tag".
dest_rank (int): A required integer identifying the destination rank.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops.operations as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.depend = ops.Depend()
>>> self.send = ops.Send(st_tag=0, dest_rank=8, group="hccl_world_group")
>>>
>>> def construct(self, x):
>>> out = self.depend(x, self.send(x))
>>> return out
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.sr_tag = sr_tag
self.group = group
def infer_shape(self, x_shape):
self.add_prim_attr("shape", x_shape)
return x_shape
def infer_dtype(self, x_dtype):
self.add_prim_attr("dtype", x_dtype)
return x_dtype
class Receive(PrimitiveWithInfer):
"""
receive tensors from src_rank.
Note:
Send and Recveive must be used in combination and have same sr_tag.
Receive must be used between servers.
Args:
sr_tag (int): A required integer identifying the send/recv message tag. The message will
will be send by the Send op with the same "sr_tag".
src_rank (int): A required integer identifying the source rank.
shape (list[int]): A required list identifying the shape of the tensor to be received.
dtype (Type): A required Type indentifying the type of the tensor to be received. The supported types:
int8, int16, int32, float16, float32.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops.operations as ops
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.recv = ops.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
>>> group="hccl_world_group")
>>>
>>> def construct(self, x):
>>> out = self.depend(x, self.recv(x))
>>> return out
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.tag = sr_tag
self.shape = shape
self.dtype = dtype
self.group = group
def infer_shape(self, x_shape=None):
return self.shape
def infer_dtype(self, x_dtype=None):
return self.dtype
class AllGather(PrimitiveWithInfer):
"""
Gathers tensors from the specified communication group.

View File

@@ -21,6 +21,7 @@ from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size
from mindspore.ops import operations as P
from mindspore.ops.operations._inner_ops import Send, Receive
from mindspore.common import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
@@ -38,7 +39,7 @@ class SendNet(nn.Cell):
super(SendNet, self).__init__()
self.x = Parameter(initializer(Tensor(x), x.shape), name='x')
self.depend = P.Depend()
self.send = P.Send(sr_tag=0, dest_rank=rank+size//2, group=NCCL_WORLD_COMM_GROUP)
self.send = Send(sr_tag=0, dest_rank=rank+size//2, group=NCCL_WORLD_COMM_GROUP)
def construct(self):
out = self.depend(self.x, self.send(self.x))
@@ -47,8 +48,8 @@ class SendNet(nn.Cell):
class RecvNet(nn.Cell):
def __init__(self):
super(RecvNet, self).__init__()
self.recv = P.Receive(sr_tag=0, src_rank=rank-size//2, shape=[3, 3, 3, 3], dtype=mstype.float32,
group=NCCL_WORLD_COMM_GROUP)
self.recv = Receive(sr_tag=0, src_rank=rank-size//2, shape=[3, 3, 3, 3], dtype=mstype.float32,
group=NCCL_WORLD_COMM_GROUP)
def construct(self):
out = self.recv()

View File

@@ -1,91 +0,0 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
grad_all = C.GradOperation(get_all=True)
class NetWithLoss(nn.Cell):
def __init__(self, network):
super(NetWithLoss, self).__init__()
self.loss = VirtualLoss()
self.network = network
def construct(self, x, y):
predict = self.network(x, y)
return self.loss(predict)
class GradWrap(nn.Cell):
def __init__(self, network):
super(GradWrap, self).__init__()
self.network = network
def construct(self, x, y):
return grad_all(self.network)(x, y)
class Net(nn.Cell):
def __init__(self, axis=0, stage1=0, stage2=0, strategy1=None, strategy2=None, shape=None, target=""):
super().__init__()
if shape is None:
shape = [64, 64]
self.gatherv2 = P.GatherV2().shard(strategy1).add_prim_attr("primitive_target", target)
self.mul = P.Mul().shard(strategy2)
self.index = Tensor(np.ones(shape), dtype=ms.int32)
self.gatherv2.set_stage(stage1)
self.mul.set_stage(stage2)
self.axis = axis
def construct(self, x, y):
out = self.gatherv2(x, self.index, self.axis)
out = self.mul(out, y)
return out
def test_gatherv2_semi_samestage1():
context.set_auto_parallel_context(device_num=8, global_rank=0, \
parallel_mode="semi_auto_parallel", pipeline_stages=2)
strategy1 = ((1, 2), (1, 1))
strategy2 = ((2, 1, 1), (2, 1, 1))
net = GradWrap(NetWithLoss(Net(0, 0, 0, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
net.set_train()
_executor.compile(net, x, y)
def test_gatherv2_semi_samestage2():
context.set_auto_parallel_context(device_num=8, global_rank=5, \
parallel_mode="semi_auto_parallel", pipeline_stages=2)
strategy1 = ((1, 2), (1, 1))
strategy2 = ((2, 1, 1), (2, 1, 1))
net = GradWrap(NetWithLoss(Net(0, 1, 1, strategy1, strategy2)))
net.set_auto_parallel()
x = Tensor(np.ones([64, 64]), dtype=ms.float32)
y = Tensor(np.ones([64, 64, 64]), dtype=ms.float32)
net.set_train()
_executor.compile(net, x, y)

View File

@@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
class DatasetLenet():
def __init__(self, data, label, length=3):
self.data = data
self.label = label
self.index = 1
self.length = length
def __iter__(self):
return self
def __next__(self):
if self.index >= self.length:
raise StopIteration
self.index += 1
return self.data, self.label
def reset(self):
self.index = 0
def get_dataset_size(self):
return 32
def get_repeat_count(self):
return 1
def get_batch_size(self):
return 32
def create_tuple_iterator(self, num_epochs=1):
return self
class MatMulCell(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.param = Parameter(initializer("zeros", [64, 64]), name="param")
self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
self.matmul = P.MatMul().shard(strategy1)
self.matmul1 = P.MatMul().shard(strategy2)
def construct(self, x):
out = self.matmul(x, self.param)
out = self.matmul1(out, self.param1)
return out
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.block = nn.CellList()
for i in range(2):
cell = MatMulCell(strategy1, strategy2)
cell.stage = i
self.block.append(cell)
def construct(self, x):
for i in range(2):
x = self.block[i](x)
return x
class PipelineSplit(nn.Cell):
def __init__(self, strategy1, strategy2):
super().__init__()
self.cell = Net(strategy1, strategy2)
def construct(self, x, label):
x = self.cell(x)
return x
def test_pipeline_split():
context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
data = Tensor(np.ones([32, 64]), dtype=ms.float32)
label = Tensor(np.ones([64, 64]), dtype=ms.float32)
strategy1 = ((4, 1), (1, 1))
strategy2 = ((2, 1), (1, 1))
net = PipelineSplit(strategy1, strategy2)
params = net.cell.block[1].trainable_params()
dataset = DatasetLenet(data, label, 3)
optimizer = nn.Lamb(params, learning_rate=0.01)
model = Model(net, optimizer=optimizer)
model.train(2, dataset, dataset_sink_mode=False)