!21207 modify replace graph for conv2d

Merge pull request !21207 from yangzhenzhang/fix-bugs-for-conv2d
i-robot 2021-08-03 08:48:37 +00:00 committed by Gitee
commit 0a3b4ff84b
10 changed files with 229 additions and 68 deletions

View File

@@ -100,7 +100,7 @@ AnfNodePtr CreatInt64Imm(int64_t value) {
return ValuePtrToAnfNodePtr(value_ptr);
}
AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple) {
AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple) {
std::vector<ValuePtr> value_list;
std::transform(tuple.begin(), tuple.end(), std::back_inserter(value_list),
[](const int64_t value) { return MakeValue(value); });

View File

@@ -41,7 +41,7 @@ AnfNodePtr CreatTypeInt(int64_t value);
AnfNodePtr CreatInt64Imm(int64_t value);
AnfNodePtr CreateInt32Tensor(int64_t value);
AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple);
AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple);
std::string HashInstanceName(const std::string &name);
class GenerateGraph {

View File

@@ -148,6 +148,9 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
return FAILED;
}
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
if (pad_mode_ == 0) { // 'pad' mode
MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
return FAILED;
@@ -160,8 +163,6 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
}
if (kernel_size_[0] <= stride_[2] || kernel_size_[1] <= stride_[3]) {
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
if (h_slice_shape % stride_[2] != 0 || w_slice_shape % stride_[3] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'same' mode do not support to split H or W when kernel_size <= stride but slice shape "
@@ -177,24 +178,18 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
return FAILED;
}
if (kernel_size_[0] <= stride_[2]) {
int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
if (h_slice_shape % stride_[2] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is "
"not divisible by stride ";
return FAILED;
}
if (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is "
"not divisible by stride ";
return FAILED;
}
if (kernel_size_[1] <= stride_[3]) {
int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
if (w_slice_shape % stride_[3] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is "
"not divisible by stride ";
return FAILED;
}
if (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) {
MS_LOG(ERROR) << name_
<< ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is "
"not divisible by stride ";
return FAILED;
}
}
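To make the divisibility rule above concrete, here is a small worked example with hypothetical numbers (illustrative only, not part of this change):

# When kernel_size <= stride, each rank's H (or W) slice must divide evenly by the stride.
def hw_split_allowed(dim_size, strategy, stride, kernel_size):
    slice_shape = dim_size // strategy
    if kernel_size <= stride and slice_shape % stride != 0:
        return False  # corresponds to returning FAILED above
    return True

print(hw_split_allowed(dim_size=8, strategy=4, stride=2, kernel_size=2))  # True: slice 2, 2 % 2 == 0
print(hw_split_allowed(dim_size=8, strategy=4, stride=3, kernel_size=2))  # False: slice 2, 2 % 3 != 0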
@@ -234,6 +229,7 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
new_out_channel_ = out_channel_ / weight_strategy[0];
} else {
out_channel_shard_ = false;
new_out_channel_ = out_channel_;
}
return SUCCESS;
@@ -527,7 +523,19 @@ void Conv2DInfo::InferOverlapShapes() {
right_recv_shape[3] = overlap_right_size_;
recv_shapes_.push_back(right_recv_shape);
}
MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_;
if (left_need_send_) {
Shape left_send_shape = input_slice_shape_;
left_send_shape[3] = left_rank_overlap_right_size_;
send_shapes_.push_back(left_send_shape);
}
if (right_need_send_) {
Shape right_send_shape = input_slice_shape_;
right_send_shape[3] = right_rank_overlap_left_size_;
send_shapes_.push_back(right_send_shape);
}
MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_ << ", the send shapes is " << send_shapes_;
}
void Conv2DInfo::InferStridedSliceAttrs() {
@@ -536,9 +544,6 @@ void Conv2DInfo::InferStridedSliceAttrs() {
left_strided_slice_end_ = input_slice_shape_;
left_strided_slice_end_[3] = left_rank_overlap_right_size_;
left_strided_slice_strides_ = {1, 1, 1, 1};
Shape left_send_shape = input_slice_shape_;
left_send_shape[3] = left_rank_overlap_right_size_;
send_shapes_.push_back(left_send_shape);
MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
<< left_strided_slice_end_;
}
@@ -548,9 +553,6 @@ void Conv2DInfo::InferStridedSliceAttrs() {
right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
right_strided_slice_end_ = input_slice_shape_;
right_strided_slice_strides_ = {1, 1, 1, 1};
Shape right_send_shape = input_slice_shape_;
right_send_shape[3] = right_rank_overlap_left_size_;
send_shapes_.push_back(right_send_shape);
MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
<< right_strided_slice_end_;
}
@@ -566,7 +568,7 @@ void Conv2DInfo::InferNewOperatorAttrs() {
InferStridedSliceAttrs();
}
OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) {
OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
auto type = cnode->Type();
MS_EXCEPTION_IF_NULL(type);
auto tensor_type = type->cast<mindspore::TensorTypePtr>();
@@ -582,7 +584,7 @@ OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) {
return attrs;
}
OperatorAttrs Conv2DInfo::CreatConv2DAttrs() {
OperatorAttrs Conv2DInfo::CreateConv2DAttrs() {
Attr out_channel = {OUT_CHANNEL, MakeValue(new_out_channel_)};
Attr kernel_size = {KERNEL_SIZE, MakeValue(kernel_size_)};
Attr mode = {MODE, MakeValue(mode_)};
@@ -592,65 +594,130 @@ OperatorAttrs Conv2DInfo::CreatConv2DAttrs() {
Attr dilation = {DILATION, MakeValue(dilation_)};
Attr group = {GROUP, MakeValue(group_)};
Attr data_format = {DATA_FORMAT, MakeValue(format_)};
OperatorAttrs attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
OperatorAttrs attrs;
if (name_.find(CONV2D_INFO) != std::string::npos) {
attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
} else { // Conv2DTranspose
attrs = {out_channel, kernel_size, pad_mode, pad, pad, mode, stride, dilation, group, data_format};
}
return attrs;
}
std::string Conv2DInfo::ReplaceNodeName() {
if (name_.find(CONV2D_INFO) != std::string::npos) {
return CONV2D;
}
if (name_.find(CONV2D_BACK_PROP_INPUT_INFO) != std::string::npos) {
return CONV2D_BACK_PROP_INPUT;
}
if (name_.find(CONV2D_TRANSPOSE_INFO) != std::string::npos) {
return CONV2D_TRANSPOSE;
}
MS_LOG(EXCEPTION) << "Invalid name: " << name_;
}
AnfNodePtr Conv2DInfo::GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode) {
auto conv2d_attrs = CreateConv2DAttrs();
auto node_name = ReplaceNodeName();
// conv2d
if (name_.find(CONV2D_INFO) != std::string::npos) {
if (cnode->size() < 3) {
MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size();
}
return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv2d_attrs), new_input, cnode->input(2)});
}
// conv2dtranspose
if (cnode->size() < 4) {
MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size();
}
return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv2d_attrs), new_input, cnode->input(2), cnode->input(3)});
}
Status Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
auto graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(graph);
GenerateGraph gen_g = GenerateGraph(attrs_);
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(ERROR) << "GenerateGraph Init failed";
return FAILED;
if (gen_g_.Init(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << "GenerateGraph Init failed";
}
if (!left_need_send_ && !right_need_send_) {
MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to send and right no need to send";
}
if (!left_need_recv_ && !right_need_recv_) {
MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to recv and right no need to recv";
}
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)};
if (left_need_send_) {
auto slice_left_begin = CreatTuple(left_strided_slice_begin_);
auto slice_left_end = CreatTuple(left_strided_slice_end_);
auto slice_left_strided = CreatTuple(left_strided_slice_strides_);
auto slice_left = gen_g.PushBack(
{gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_left_begin, slice_left_end, slice_left_strided});
auto slice_left_begin = CreateTuple(left_strided_slice_begin_);
auto slice_left_end = CreateTuple(left_strided_slice_end_);
auto slice_left_strided = CreateTuple(left_strided_slice_strides_);
auto slice_left = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_left_begin,
slice_left_end, slice_left_strided});
make_tuple_a_inputs.push_back(slice_left);
input_nodes.push_back(std::make_pair(slice_left, 1));
}
if (right_need_send_) {
auto slice_right_begin = CreatTuple(right_strided_slice_begin_);
auto slice_right_end = CreatTuple(right_strided_slice_end_);
auto slice_right_strided = CreatTuple(right_strided_slice_strides_);
auto slice_right = gen_g.PushBack(
{gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_right_begin, slice_right_end, slice_right_strided});
auto slice_right_begin = CreateTuple(right_strided_slice_begin_);
auto slice_right_end = CreateTuple(right_strided_slice_end_);
auto slice_right_strided = CreateTuple(right_strided_slice_strides_);
auto slice_right = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_right_begin,
slice_right_end, slice_right_strided});
make_tuple_a_inputs.push_back(slice_right);
input_nodes.push_back(std::make_pair(slice_right, 1));
}
auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
auto alltoall_attrs = CreatNeighborExchangeAttrs(cnode);
auto alltoall_v = gen_g.PushBack({gen_g.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
auto alltoall_attrs = CreateNeighborExchangeAttrs(cnode);
auto alltoall_v = gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
AnfNodePtr conv2d;
Attr concat_axis = {AXIS, MakeValue(-1)};
OperatorAttrs concat_attrs = {concat_axis};
if (left_need_recv_) {
std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
CreatInt64Imm(0)};
auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), cnode->input(1),
tuple_getitem_l};
std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), tuple_getitem_l,
cnode->input(1)};
auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
auto concat_l = gen_g.PushBack({gen_g.NewOpInst(CONCAT), make_tuple_l});
make_tuple_inputs.push_back(concat_l);
auto concat_l = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_l});
if (right_need_recv_) {
std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
CreatInt64Imm(1)};
auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
std::vector<AnfNodePtr> make_tuple_r_inputs = {NewValueNode(prim::kPrimMakeTuple), concat_l, tuple_getitem_r};
auto make_tuple_r = graph->NewCNode(make_tuple_r_inputs);
auto concat_r = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r});
conv2d = GenerateConv2DNode(concat_r, cnode);
} else {
conv2d = GenerateConv2DNode(concat_l, cnode);
}
} else { // left no need recv, and right need recv
std::vector<AnfNodePtr> tuple_getitem_r_inputs_1 = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
CreatInt64Imm(0)};
auto tuple_getitem_r_1 = graph->NewCNode(tuple_getitem_r_inputs_1);
std::vector<AnfNodePtr> make_tuple_r_inputs_1 = {NewValueNode(prim::kPrimMakeTuple), gen_g_.virtual_input_node(),
tuple_getitem_r_1};
auto make_tuple_r_1 = graph->NewCNode(make_tuple_r_inputs_1);
input_nodes.push_back(std::make_pair(make_tuple_r_1, 1));
auto concat_r_1 = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r_1});
conv2d = GenerateConv2DNode(concat_r_1, cnode);
}
if (right_need_recv_) {
std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
CreatInt64Imm(0)};
auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
make_tuple_inputs.push_back(tuple_getitem_r);
} else {
make_tuple_inputs.push_back(cnode->input(1));
}
auto make_tuple = graph->NewCNode(make_tuple_inputs);
Attr concat_axis = {AXIS, MakeValue(-1)};
OperatorAttrs concat_attrs = {concat_axis};
std::vector<AnfNodePtr> concat_inputs = {gen_g.NewOpInst(CONCAT, concat_attrs), make_tuple};
auto concat = graph->NewCNode(concat_inputs);
auto conv2d_attrs = CreatConv2DAttrs();
auto conv2d = gen_g.PushBack({gen_g.NewOpInst(CONV2D, conv2d_attrs), concat, cnode->input(2)});
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, conv2d));
return SUCCESS;
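The replace graph assembled above implements a halo exchange along W: StridedSlice cuts the boundary columns each neighbor needs, NeighborExchange swaps those halos between adjacent ranks, Concat stitches the received halos onto the local slice, and the regenerated Conv2D/Conv2DTranspose node then runs on the widened slice. A NumPy sketch of the concatenation step, with illustrative names and toy shapes rather than the MindSpore API:

import numpy as np

def widen_with_halos(local_slice, left_halo, right_halo):
    # Concatenate halos received from the W neighbors onto the local NCHW slice (axis -1),
    # mirroring the TupleGetItem/Concat chain built above.
    parts = []
    if left_halo is not None:   # corresponds to left_need_recv_
        parts.append(left_halo)
    parts.append(local_slice)
    if right_halo is not None:  # corresponds to right_need_recv_
        parts.append(right_halo)
    return np.concatenate(parts, axis=-1)

# toy shapes: a W slice of width 2 plus one-column halos from each neighbor
local = np.ones((1, 16, 8, 2))
left = np.zeros((1, 16, 8, 1))
right = np.zeros((1, 16, 8, 1))
print(widen_with_halos(local, left, right).shape)  # (1, 16, 8, 4), fed into the new conv node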

View File

@@ -23,6 +23,7 @@
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
@@ -57,9 +58,11 @@ class Conv2DInfo : public OperatorInfo {
void InferSendRecvFlag();
void InferOverlapShapes();
void InferStridedSliceAttrs();
std::string ReplaceNodeName();
AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
OperatorAttrs CreatNeighborExchangeAttrs(const CNodePtr &cnode);
OperatorAttrs CreatConv2DAttrs();
OperatorAttrs CreateNeighborExchangeAttrs(const CNodePtr &cnode);
OperatorAttrs CreateConv2DAttrs();
Status ComputeReplaceGraph(const CNodePtr &cnode);
int64_t out_channel_ = 1;
@@ -106,6 +109,8 @@ class Conv2DInfo : public OperatorInfo {
Shapes send_shapes_;
Shapes recv_shapes_;
GenerateGraph gen_g_ = GenerateGraph(attrs_);
virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
virtual void InferNewPadList();
virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);

View File

@@ -172,6 +172,22 @@ Status GatherDInfo::InferMirrorOps() {
return SUCCESS;
}
void GatherDInfo::ReComputeBatchSplitFlagList() {
if (InferAttrs() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
}
if (dim_ == 0) {
MS_LOG(EXCEPTION)
<< name_
<< ": Can not generate batch data parallel strategy since the dim is 0, please set others strategy for it";
}
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
split_flag_list_[i] = true;
}
}
Status GatherDInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
std::vector<StrategyPtr> GatherDInfo::GenerateOpStrategies(int64_t stage_id) {

View File

@@ -40,6 +40,7 @@ class GatherDInfo : public OperatorInfo {
Status InitForCostModel(const StrategyPtr &strategy) override;
std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override;

View File

@@ -283,6 +283,9 @@ constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
constexpr char CONV2D[] = "Conv2D";
constexpr char CONV2D_BACK_PROP_INPUT[] = "Conv2DBackpropInput";
constexpr char CONV2D_TRANSPOSE[] = "Conv2DTranspose";
constexpr char CONV2D_INFO[] = "Conv2DInfo";
constexpr char CONV2D_BACK_PROP_INPUT_INFO[] = "Conv2DBackpropInputInfo";
constexpr char CONV2D_TRANSPOSE_INFO[] = "Conv2DTransposeInfo";
constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx";
constexpr char BATCH_NORM[] = "BatchNorm";

View File

@@ -39,6 +39,8 @@ class Net(Cell):
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
_w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
@@ -75,6 +77,31 @@ def test_conv2d_model_parallel2():
compile_net(net)
def test_conv2d_model_parallel3():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
strategy2 = ((2, 1, 1, 4),)
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_conv2d_model_parallel4():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
strategy2 = ((2, 2, 1, 4),)
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_conv2d_left_and_right_no_need_to_send():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
strategy2 = ((2, 1, 1, 4),)
net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
with pytest.raises(RuntimeError):
compile_net(net)
def test_conv2d_output_can_not_divisible_by_strategy():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))

View File

@@ -36,8 +36,24 @@ class Net(Cell):
return out
class Net2(Cell):
def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
strategy1=None, strategy2=None):
super().__init__()
self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
pad_mode=pad_mode, stride=stride).shard(strategy1)
self.neg = P.Neg().shard(strategy2)
self.weight = Parameter(conv2d_weight, "w1")
def construct(self, x, b):
out = self.conv2d_transpose(x, self.weight, (32, 16, 16, 16))
out = self.neg(out)
return out
_x = Tensor(np.ones([32, 8, 8, 8]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
@@ -64,3 +80,21 @@ def test_conv2d_transpose_model_parallel1():
strategy2 = ((8, 1, 1, 1),)
net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_conv2d_transpose_model_parallel2():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
strategy2 = ((2, 1, 1, 4),)
net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)
def test_conv2d_transpose_model_parallel3():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
strategy2 = ((2, 2, 1, 4),)
net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
strategy1=strategy1, strategy2=strategy2)
compile_net(net)

View File

@@ -65,6 +65,14 @@ def test_gathernd_dim2():
compile_net(net)
def test_gathernd_dim2_default_batch_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = None
strategy2 = ((2, 8, 1),)
net = Net(2, _w1, strategy1, strategy2)
compile_net(net)
def test_gathernd_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
net = Net(1, _w1)