diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc
index 113227e56e3..2658c3042a2 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc
@@ -100,7 +100,7 @@ AnfNodePtr CreatInt64Imm(int64_t value) {
   return ValuePtrToAnfNodePtr(value_ptr);
 }
 
-AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple) {
+AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple) {
   std::vector<ValuePtr> value_list;
   std::transform(tuple.begin(), tuple.end(), std::back_inserter(value_list),
                  [](const int64_t value) { return MakeValue(value); });
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h
index 55801c0af5f..12c0c6bc157 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.h
@@ -41,7 +41,7 @@ AnfNodePtr CreatTypeInt(int64_t value);
 AnfNodePtr CreatInt64Imm(int64_t value);
 AnfNodePtr CreateInt32Tensor(int64_t value);
 AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
-AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple);
+AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple);
 std::string HashInstanceName(const std::string &name);
 
 class GenerateGraph {
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
index 8fc52daed14..39d998aa2aa 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.cc
@@ -148,6 +148,9 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
     return FAILED;
   }
 
+  int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
+  int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
+
   if (pad_mode_ == 0) {  // 'pad' mode
     MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
     return FAILED;
@@ -160,8 +163,6 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
   }
 
   if (kernel_size_[0] <= stride_[2] || kernel_size_[1] <= stride_[3]) {
-    int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
-    int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
     if (h_slice_shape % stride_[2] != 0 || w_slice_shape % stride_[3] != 0) {
       MS_LOG(ERROR) << name_
                     << ": The 'same' mode do not support to split H or W when kernel_size <= stride but slice shape "
@@ -177,24 +178,18 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
       return FAILED;
     }
 
-    if (kernel_size_[0] <= stride_[2]) {
-      int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
-      if (h_slice_shape % stride_[2] != 0) {
-        MS_LOG(ERROR) << name_
-                      << ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is "
-                         "not divisible by stride ";
-        return FAILED;
-      }
+    if (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) {
+      MS_LOG(ERROR) << name_
+                    << ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is "
+                       "not divisible by stride ";
+      return FAILED;
     }
 
-    if (kernel_size_[1] <= stride_[3]) {
-      int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
-      if (w_slice_shape % stride_[3] != 0) {
-        MS_LOG(ERROR) << name_
-                      << ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is "
-                         "not divisible by stride ";
-        return FAILED;
-      }
+    if (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) {
+      MS_LOG(ERROR) << name_
+                    << ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is "
+                       "not divisible by stride ";
+      return FAILED;
     }
   }
 
@@ -234,6 +229,7 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
     new_out_channel_ = out_channel_ / weight_strategy[0];
   } else {
     out_channel_shard_ = false;
+    new_out_channel_ = out_channel_;
   }
 
   return SUCCESS;
@@ -527,7 +523,19 @@ void Conv2DInfo::InferOverlapShapes() {
     right_recv_shape[3] = overlap_right_size_;
     recv_shapes_.push_back(right_recv_shape);
   }
-  MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_;
+
+  if (left_need_send_) {
+    Shape left_send_shape = input_slice_shape_;
+    left_send_shape[3] = left_rank_overlap_right_size_;
+    send_shapes_.push_back(left_send_shape);
+  }
+
+  if (right_need_send_) {
+    Shape right_send_shape = input_slice_shape_;
+    right_send_shape[3] = right_rank_overlap_left_size_;
+    send_shapes_.push_back(right_send_shape);
+  }
+  MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_ << ", the send shapes is " << send_shapes_;
 }
 
 void Conv2DInfo::InferStridedSliceAttrs() {
@@ -536,9 +544,6 @@ void Conv2DInfo::InferStridedSliceAttrs() {
     left_strided_slice_end_ = input_slice_shape_;
     left_strided_slice_end_[3] = left_rank_overlap_right_size_;
     left_strided_slice_strides_ = {1, 1, 1, 1};
-    Shape left_send_shape = input_slice_shape_;
-    left_send_shape[3] = left_rank_overlap_right_size_;
-    send_shapes_.push_back(left_send_shape);
     MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
                  << left_strided_slice_end_;
   }
@@ -548,9 +553,6 @@ void Conv2DInfo::InferStridedSliceAttrs() {
     right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
     right_strided_slice_end_ = input_slice_shape_;
     right_strided_slice_strides_ = {1, 1, 1, 1};
-    Shape right_send_shape = input_slice_shape_;
-    right_send_shape[3] = right_rank_overlap_left_size_;
-    send_shapes_.push_back(right_send_shape);
     MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
                  << right_strided_slice_end_;
   }
@@ -566,7 +568,7 @@ void Conv2DInfo::InferNewOperatorAttrs() {
   InferStridedSliceAttrs();
 }
 
-OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) {
+OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
   auto type = cnode->Type();
   MS_EXCEPTION_IF_NULL(type);
   auto tensor_type = type->cast<TensorTypePtr>();
@@ -582,7 +584,7 @@ OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) {
   return attrs;
 }
 
-OperatorAttrs Conv2DInfo::CreatConv2DAttrs() {
+OperatorAttrs Conv2DInfo::CreateConv2DAttrs() {
   Attr out_channel = {OUT_CHANNEL, MakeValue(new_out_channel_)};
   Attr kernel_size = {KERNEL_SIZE, MakeValue(kernel_size_)};
   Attr mode = {MODE, MakeValue(mode_)};
@@ -592,65 +594,130 @@ OperatorAttrs Conv2DInfo::CreatConv2DAttrs() {
   Attr dilation = {DILATION, MakeValue(dilation_)};
   Attr group = {GROUP, MakeValue(group_)};
   Attr data_format = {DATA_FORMAT, MakeValue(format_)};
-  OperatorAttrs attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
+
+  OperatorAttrs attrs;
+  if (name_.find(CONV2D_INFO) != std::string::npos) {
+    attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
+  } else {  // Conv2DTranspose
+    attrs = {out_channel, kernel_size, pad_mode, pad, pad, mode, stride, dilation, group, data_format};
+  }
+
   return attrs;
 }
 
+std::string Conv2DInfo::ReplaceNodeName() {
+  if (name_.find(CONV2D_INFO) != std::string::npos) {
+    return CONV2D;
+  }
+
+  if (name_.find(CONV2D_BACK_PROP_INPUT_INFO) != std::string::npos) {
+    return CONV2D_BACK_PROP_INPUT;
+  }
+
+  if (name_.find(CONV2D_TRANSPOSE_INFO) != std::string::npos) {
+    return CONV2D_TRANSPOSE;
+  }
+
+  MS_LOG(EXCEPTION) << "Invalid name: " << name_;
+}
+
+AnfNodePtr Conv2DInfo::GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode) {
+  auto conv2d_attrs = CreateConv2DAttrs();
+  auto node_name = ReplaceNodeName();
+
+  // conv2d
+  if (name_.find(CONV2D_INFO) != std::string::npos) {
+    if (cnode->size() < 3) {
+      MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size();
+    }
+    return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv2d_attrs), new_input, cnode->input(2)});
+  }
+
+  // conv2dtranspose
+  if (cnode->size() < 4) {
+    MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size();
+  }
+  return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv2d_attrs), new_input, cnode->input(2), cnode->input(3)});
+}
+
 Status Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto graph = cnode->func_graph();
   MS_EXCEPTION_IF_NULL(graph);
-  GenerateGraph gen_g = GenerateGraph(attrs_);
-  if (gen_g.Init(cnode) != SUCCESS) {
-    MS_LOG(ERROR) << "GenerateGraph Init failed";
-    return FAILED;
+
+  if (gen_g_.Init(cnode) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "GenerateGraph Init failed";
   }
+
+  if (!left_need_send_ && !right_need_send_) {
+    MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to send and right no need to send";
+  }
+
+  if (!left_need_recv_ && !right_need_recv_) {
+    MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to recv and right no need to recv";
+  }
+
   std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
   std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)};
   if (left_need_send_) {
-    auto slice_left_begin = CreatTuple(left_strided_slice_begin_);
-    auto slice_left_end = CreatTuple(left_strided_slice_end_);
-    auto slice_left_strided = CreatTuple(left_strided_slice_strides_);
-    auto slice_left = gen_g.PushBack(
-      {gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_left_begin, slice_left_end, slice_left_strided});
+    auto slice_left_begin = CreateTuple(left_strided_slice_begin_);
+    auto slice_left_end = CreateTuple(left_strided_slice_end_);
+    auto slice_left_strided = CreateTuple(left_strided_slice_strides_);
+    auto slice_left = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_left_begin,
+                                       slice_left_end, slice_left_strided});
     make_tuple_a_inputs.push_back(slice_left);
+    input_nodes.push_back(std::make_pair(slice_left, 1));
   }
   if (right_need_send_) {
-    auto slice_right_begin = CreatTuple(right_strided_slice_begin_);
-    auto slice_right_end = CreatTuple(right_strided_slice_end_);
-    auto slice_right_strided = CreatTuple(right_strided_slice_strides_);
-    auto slice_right = gen_g.PushBack(
-      {gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_right_begin, slice_right_end, slice_right_strided});
+    auto slice_right_begin = CreateTuple(right_strided_slice_begin_);
+    auto slice_right_end = CreateTuple(right_strided_slice_end_);
+    auto slice_right_strided = CreateTuple(right_strided_slice_strides_);
+    auto slice_right = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(),
+                                        slice_right_begin, slice_right_end, slice_right_strided});
     make_tuple_a_inputs.push_back(slice_right);
+    input_nodes.push_back(std::make_pair(slice_right, 1));
   }
+
   auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
-  auto alltoall_attrs = CreatNeighborExchangeAttrs(cnode);
-  auto alltoall_v = gen_g.PushBack({gen_g.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
-  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  auto alltoall_attrs = CreateNeighborExchangeAttrs(cnode);
+  auto alltoall_v = gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
+
+  AnfNodePtr conv2d;
+  Attr concat_axis = {AXIS, MakeValue(-1)};
+  OperatorAttrs concat_attrs = {concat_axis};
+
   if (left_need_recv_) {
     std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
                                                       CreatInt64Imm(0)};
     auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
-    std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), cnode->input(1),
-                                                   tuple_getitem_l};
+    std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), tuple_getitem_l,
+                                                   cnode->input(1)};
     auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
-    auto concat_l = gen_g.PushBack({gen_g.NewOpInst(CONCAT), make_tuple_l});
-    make_tuple_inputs.push_back(concat_l);
+    auto concat_l = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_l});
+
+    if (right_need_recv_) {
+      std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
+                                                        CreatInt64Imm(1)};
+      auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
+      std::vector<AnfNodePtr> make_tuple_r_inputs = {NewValueNode(prim::kPrimMakeTuple), concat_l, tuple_getitem_r};
+      auto make_tuple_r = graph->NewCNode(make_tuple_r_inputs);
+      auto concat_r = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r});
+      conv2d = GenerateConv2DNode(concat_r, cnode);
+    } else {
+      conv2d = GenerateConv2DNode(concat_l, cnode);
+    }
+  } else {  // left no need recv, and right need recv
+    std::vector<AnfNodePtr> tuple_getitem_r_inputs_1 = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
+                                                        CreatInt64Imm(0)};
+    auto tuple_getitem_r_1 = graph->NewCNode(tuple_getitem_r_inputs_1);
+    std::vector<AnfNodePtr> make_tuple_r_inputs_1 = {NewValueNode(prim::kPrimMakeTuple), gen_g_.virtual_input_node(),
+                                                     tuple_getitem_r_1};
+    auto make_tuple_r_1 = graph->NewCNode(make_tuple_r_inputs_1);
+    input_nodes.push_back(std::make_pair(make_tuple_r_1, 1));
+
+    auto concat_r_1 = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r_1});
+    conv2d = GenerateConv2DNode(concat_r_1, cnode);
   }
-  if (right_need_recv_) {
-    std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
-                                                      CreatInt64Imm(0)};
-    auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
-    make_tuple_inputs.push_back(tuple_getitem_r);
-  } else {
-    make_tuple_inputs.push_back(cnode->input(1));
-  }
-  auto make_tuple = graph->NewCNode(make_tuple_inputs);
-  Attr concat_axis = {AXIS, MakeValue(-1)};
-  OperatorAttrs concat_attrs = {concat_axis};
-  std::vector<AnfNodePtr> concat_inputs = {gen_g.NewOpInst(CONCAT, concat_attrs), make_tuple};
-  auto concat = graph->NewCNode(concat_inputs);
-  auto conv2d_attrs = CreatConv2DAttrs();
-  auto conv2d = gen_g.PushBack({gen_g.NewOpInst(CONV2D, conv2d_attrs), concat, cnode->input(2)});
+
   replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
     std::make_pair(input_nodes, conv2d));
   return SUCCESS;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
index 1ae1e4a752a..3786dc5f826 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/conv2d_info.h
@@ -23,6 +23,7 @@
 #include <vector>
 
 #include "ir/value.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
 #include "frontend/parallel/ops_info/operator_info.h"
 #include "frontend/parallel/strategy.h"
@@ -57,9 +58,11 @@ class Conv2DInfo : public OperatorInfo {
   void InferSendRecvFlag();
   void InferOverlapShapes();
   void InferStridedSliceAttrs();
+  std::string ReplaceNodeName();
+  AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
-  OperatorAttrs CreatNeighborExchangeAttrs(const CNodePtr &cnode);
-  OperatorAttrs CreatConv2DAttrs();
+  OperatorAttrs CreateNeighborExchangeAttrs(const CNodePtr &cnode);
+  OperatorAttrs CreateConv2DAttrs();
   Status ComputeReplaceGraph(const CNodePtr &cnode);
 
   int64_t out_channel_ = 1;
@@ -106,6 +109,8 @@ class Conv2DInfo : public OperatorInfo {
   Shapes send_shapes_;
   Shapes recv_shapes_;
 
+  GenerateGraph gen_g_ = GenerateGraph(attrs_);
+
   virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
   virtual void InferNewPadList();
   virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.cc
index 35cd2405c03..64a2a0b3b83 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.cc
@@ -172,6 +172,22 @@ Status GatherDInfo::InferMirrorOps() {
   return SUCCESS;
 }
 
+void GatherDInfo::ReComputeBatchSplitFlagList() {
+  if (InferAttrs() != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
+  }
+
+  if (dim_ == 0) {
+    MS_LOG(EXCEPTION)
+      << name_
+      << ": Can not generate batch data parallel strategy since the dim is 0, please set others strategy for it";
+  }
+
+  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
+    split_flag_list_[i] = true;
+  }
+}
+
 Status GatherDInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
 
 std::vector<StrategyPtr> GatherDInfo::GenerateOpStrategies(int64_t stage_id) {
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.h
index 8288fe11ae1..1d8a2fe24d2 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/gatherd_info.h
@@ -40,6 +40,7 @@ class GatherDInfo : public OperatorInfo {
   Status InitForCostModel(const StrategyPtr &strategy) override;
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
   Status SetCostUnderStrategy(const StrategyPtr &) override;
+  void ReComputeBatchSplitFlagList() override;
 
  protected:
   Status GetAttrs() override;
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index dd6a3237da5..c0c89beb245 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -283,6 +283,9 @@ constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
 constexpr char CONV2D[] = "Conv2D";
 constexpr char CONV2D_BACK_PROP_INPUT[] = "Conv2DBackpropInput";
 constexpr char CONV2D_TRANSPOSE[] = "Conv2DTranspose";
+constexpr char CONV2D_INFO[] = "Conv2DInfo";
+constexpr char CONV2D_BACK_PROP_INPUT_INFO[] = "Conv2DBackpropInputInfo";
+constexpr char CONV2D_TRANSPOSE_INFO[] = "Conv2DTransposeInfo";
 constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
 constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx";
 constexpr char BATCH_NORM[] = "BatchNorm";
diff --git a/tests/ut/python/parallel/test_conv2d.py b/tests/ut/python/parallel/test_conv2d.py
index 4309b707513..1ef971a0587 100644
--- a/tests/ut/python/parallel/test_conv2d.py
+++ b/tests/ut/python/parallel/test_conv2d.py
@@ -39,6 +39,8 @@ class Net(Cell):
 
 _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
 _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
+_w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
+_w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
 _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
 
 
@@ -75,6 +77,31 @@ def test_conv2d_model_parallel2():
     compile_net(net)
 
 
+def test_conv2d_model_parallel3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
+    strategy2 = ((2, 1, 1, 4),)
+    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
+def test_conv2d_model_parallel4():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
+    strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
+    strategy2 = ((2, 2, 1, 4),)
+    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
+def test_conv2d_left_and_right_no_need_to_send():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
+    strategy2 = ((2, 1, 1, 4),)
+    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
 def test_conv2d_output_can_not_divisible_by_strategy():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
diff --git a/tests/ut/python/parallel/test_conv2d_transpose.py b/tests/ut/python/parallel/test_conv2d_transpose.py
index e5cc5d12027..46b65a2ea86 100644
--- a/tests/ut/python/parallel/test_conv2d_transpose.py
+++ b/tests/ut/python/parallel/test_conv2d_transpose.py
@@ -36,8 +36,24 @@ class Net(Cell):
         return out
 
 
+class Net2(Cell):
+    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
+                 strategy1=None, strategy2=None):
+        super().__init__()
+        self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
+                                                  pad_mode=pad_mode, stride=stride).shard(strategy1)
+        self.neg = P.Neg().shard(strategy2)
+        self.weight = Parameter(conv2d_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.conv2d_transpose(x, self.weight, (32, 16, 16, 16))
+        out = self.neg(out)
+        return out
+
+
 _x = Tensor(np.ones([32, 8, 8, 8]), dtype=ms.float32)
 _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
+_w2 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
 _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
 
 
@@ -64,3 +80,21 @@ def test_conv2d_transpose_model_parallel1():
     strategy2 = ((8, 1, 1, 1),)
     net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
     compile_net(net)
+
+
+def test_conv2d_transpose_model_parallel2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
+    strategy2 = ((2, 1, 1, 4),)
+    net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
+               strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
+def test_conv2d_transpose_model_parallel3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
+    strategy2 = ((2, 2, 1, 4),)
+    net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
+               strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
diff --git a/tests/ut/python/parallel/test_gatherd.py b/tests/ut/python/parallel/test_gatherd.py
index 2ee2a9c7964..abdcdd69391 100644
--- a/tests/ut/python/parallel/test_gatherd.py
+++ b/tests/ut/python/parallel/test_gatherd.py
@@ -65,6 +65,14 @@ def test_gathernd_dim2():
     compile_net(net)
 
 
+def test_gathernd_dim2_default_batch_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = None
+    strategy2 = ((2, 8, 1),)
+    net = Net(2, _w1, strategy1, strategy2)
+    compile_net(net)
+
+
 def test_gathernd_auto_parallel():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
     net = Net(1, _w1)