forked from mindspore-Ecosystem/mindspore

commit 0a3b4ff84b

!21207 modify replace graph for conv2d

Merge pull request !21207 from yangzhenzhang/fix-bugs-for-conv2d
@@ -100,7 +100,7 @@ AnfNodePtr CreatInt64Imm(int64_t value) {
   return ValuePtrToAnfNodePtr(value_ptr);
 }
 
-AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple) {
+AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple) {
   std::vector<ValuePtr> value_list;
   std::transform(tuple.begin(), tuple.end(), std::back_inserter(value_list),
                  [](const int64_t value) { return MakeValue(value); });
@@ -41,7 +41,7 @@ AnfNodePtr CreatTypeInt(int64_t value);
 AnfNodePtr CreatInt64Imm(int64_t value);
 AnfNodePtr CreateInt32Tensor(int64_t value);
 AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr);
-AnfNodePtr CreatTuple(const std::vector<int64_t> &tuple);
+AnfNodePtr CreateTuple(const std::vector<int64_t> &tuple);
 std::string HashInstanceName(const std::string &name);
 
 class GenerateGraph {
@@ -148,6 +148,9 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
     return FAILED;
   }
 
+  int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
+  int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
+
   if (pad_mode_ == 0) {  // 'pad' mode
     MS_LOG(ERROR) << name_ << ": The 'pad' mode do not support to split H or W";
     return FAILED;
@@ -160,8 +163,6 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
   }
 
   if (kernel_size_[0] <= stride_[2] || kernel_size_[1] <= stride_[3]) {
-    int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
-    int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
     if (h_slice_shape % stride_[2] != 0 || w_slice_shape % stride_[3] != 0) {
       MS_LOG(ERROR) << name_
                     << ": The 'same' mode do not support to split H or W when kernel_size <= stride but slice shape "
@@ -177,24 +178,18 @@ Status Conv2DInfo::CheckHWStrategy(int64_t h_strategy, int64_t w_strategy) {
       return FAILED;
     }
 
-    if (kernel_size_[0] <= stride_[2]) {
-      int64_t h_slice_shape = inputs_shape_[0][2] / h_strategy;
-      if (h_slice_shape % stride_[2] != 0) {
-        MS_LOG(ERROR) << name_
-                      << ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is "
-                         "not divisible by stride ";
-        return FAILED;
-      }
-    }
+    if (kernel_size_[0] <= stride_[2] && h_slice_shape % stride_[2] != 0) {
+      MS_LOG(ERROR) << name_
+                    << ": The 'valid' mode do not support to split H when kernel_size <= stride but slice shape is "
+                       "not divisible by stride ";
+      return FAILED;
+    }
 
-    if (kernel_size_[1] <= stride_[3]) {
-      int64_t w_slice_shape = inputs_shape_[0][3] / w_strategy;
-      if (w_slice_shape % stride_[3] != 0) {
-        MS_LOG(ERROR) << name_
-                      << ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is "
-                         "not divisible by stride ";
-        return FAILED;
-      }
-    }
+    if (kernel_size_[1] <= stride_[3] && w_slice_shape % stride_[3] != 0) {
+      MS_LOG(ERROR) << name_
+                    << ": The 'valid' mode do not support to split W when kernel_size <= stride but slice shape is "
+                       "not divisible by stride ";
+      return FAILED;
+    }
   }
 
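The three hunks above hoist h_slice_shape and w_slice_shape to the top of CheckHWStrategy and flatten the nested 'valid'-mode checks into single conditions. The rule being enforced, as a standalone sketch (hypothetical helper, not MindSpore API; assumes NCHW layout and that the dimension divides evenly by the strategy):

#include <cstdint>
#include <iostream>

// Hypothetical helper mirroring the rule above: when a spatial dimension is
// split and kernel_size <= stride, each rank's slice extent must be divisible
// by the stride, or neighboring ranks would not start their windows on a
// stride boundary.
bool SplitAllowed(int64_t dim_size, int64_t strategy, int64_t kernel, int64_t stride) {
  if (strategy == 1) {
    return true;  // dimension is not split; nothing to check
  }
  int64_t slice_shape = dim_size / strategy;  // per-rank slice extent
  return !(kernel <= stride && slice_shape % stride != 0);
}

int main() {
  std::cout << SplitAllowed(8, 4, 2, 2) << '\n';  // slice 2, 2 % 2 == 0 -> 1 (allowed)
  std::cout << SplitAllowed(8, 4, 2, 3) << '\n';  // slice 2, 2 % 3 != 0 -> 0 (rejected)
}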
@@ -234,6 +229,7 @@ Status Conv2DInfo::CheckStrategyBase(const StrategyPtr &strategy) {
     new_out_channel_ = out_channel_ / weight_strategy[0];
   } else {
+    out_channel_shard_ = false;
     new_out_channel_ = out_channel_;
   }
 
   return SUCCESS;
@@ -527,7 +523,19 @@ void Conv2DInfo::InferOverlapShapes() {
     right_recv_shape[3] = overlap_right_size_;
     recv_shapes_.push_back(right_recv_shape);
   }
-  MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_;
+
+  if (left_need_send_) {
+    Shape left_send_shape = input_slice_shape_;
+    left_send_shape[3] = left_rank_overlap_right_size_;
+    send_shapes_.push_back(left_send_shape);
+  }
+
+  if (right_need_send_) {
+    Shape right_send_shape = input_slice_shape_;
+    right_send_shape[3] = right_rank_overlap_left_size_;
+    send_shapes_.push_back(right_send_shape);
+  }
+  MS_LOG(INFO) << name_ << ": the recv shapes is " << recv_shapes_ << ", the send shapes is " << send_shapes_;
 }
 
 void Conv2DInfo::InferStridedSliceAttrs() {
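A detail worth noting in this hunk: send shapes now live beside the recv shapes, and both are just the local input slice with the W axis (index 3) narrowed to the exchanged overlap. A minimal sketch of the shape arithmetic (hypothetical Shape4D stand-in, NCHW assumed):

#include <array>
#include <cstdint>
#include <cstdio>

using Shape4D = std::array<int64_t, 4>;  // hypothetical stand-in for Shape (NCHW)

// A send/recv buffer keeps N, C and H of the local slice; only W (index 3)
// shrinks to the overlap width being exchanged with the neighbor.
Shape4D OverlapBufferShape(Shape4D input_slice_shape, int64_t overlap_w) {
  input_slice_shape[3] = overlap_w;
  return input_slice_shape;
}

int main() {
  Shape4D s = OverlapBufferShape({16, 16, 8, 2}, 1);
  std::printf("[%lld, %lld, %lld, %lld]\n", (long long)s[0], (long long)s[1], (long long)s[2], (long long)s[3]);
}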
@@ -536,9 +544,6 @@ void Conv2DInfo::InferStridedSliceAttrs() {
     left_strided_slice_end_ = input_slice_shape_;
     left_strided_slice_end_[3] = left_rank_overlap_right_size_;
     left_strided_slice_strides_ = {1, 1, 1, 1};
-    Shape left_send_shape = input_slice_shape_;
-    left_send_shape[3] = left_rank_overlap_right_size_;
-    send_shapes_.push_back(left_send_shape);
     MS_LOG(INFO) << name_ << ": The left strided slice begin is " << left_strided_slice_begin_ << ", end is "
                  << left_strided_slice_end_;
   }
@@ -548,9 +553,6 @@ void Conv2DInfo::InferStridedSliceAttrs() {
     right_strided_slice_begin_[3] = input_slice_shape_[3] - right_rank_overlap_left_size_;
     right_strided_slice_end_ = input_slice_shape_;
     right_strided_slice_strides_ = {1, 1, 1, 1};
-    Shape right_send_shape = input_slice_shape_;
-    right_send_shape[3] = right_rank_overlap_left_size_;
-    send_shapes_.push_back(right_send_shape);
     MS_LOG(INFO) << name_ << ": The right strided slice begin is " << right_strided_slice_begin_ << ", end is "
                  << right_strided_slice_end_;
   }
@@ -566,7 +568,7 @@ void Conv2DInfo::InferNewOperatorAttrs() {
   InferStridedSliceAttrs();
 }
 
-OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) {
+OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
   auto type = cnode->Type();
   MS_EXCEPTION_IF_NULL(type);
   auto tensor_type = type->cast<mindspore::TensorTypePtr>();
@@ -582,7 +584,7 @@ OperatorAttrs Conv2DInfo::CreatNeighborExchangeAttrs(const CNodePtr &cnode) {
   return attrs;
 }
 
-OperatorAttrs Conv2DInfo::CreatConv2DAttrs() {
+OperatorAttrs Conv2DInfo::CreateConv2DAttrs() {
   Attr out_channel = {OUT_CHANNEL, MakeValue(new_out_channel_)};
   Attr kernel_size = {KERNEL_SIZE, MakeValue(kernel_size_)};
   Attr mode = {MODE, MakeValue(mode_)};
@@ -592,65 +594,130 @@ OperatorAttrs Conv2DInfo::CreatConv2DAttrs() {
   Attr dilation = {DILATION, MakeValue(dilation_)};
   Attr group = {GROUP, MakeValue(group_)};
   Attr data_format = {DATA_FORMAT, MakeValue(format_)};
-  OperatorAttrs attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
+
+  OperatorAttrs attrs;
+  if (name_.find(CONV2D_INFO) != std::string::npos) {
+    attrs = {out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format};
+  } else {  // Conv2DTranspose
+    attrs = {out_channel, kernel_size, pad_mode, pad, pad, mode, stride, dilation, group, data_format};
+  }
+
   return attrs;
 }
 
+std::string Conv2DInfo::ReplaceNodeName() {
+  if (name_.find(CONV2D_INFO) != std::string::npos) {
+    return CONV2D;
+  }
+
+  if (name_.find(CONV2D_BACK_PROP_INPUT_INFO) != std::string::npos) {
+    return CONV2D_BACK_PROP_INPUT;
+  }
+
+  if (name_.find(CONV2D_TRANSPOSE_INFO) != std::string::npos) {
+    return CONV2D_TRANSPOSE;
+  }
+
+  MS_LOG(EXCEPTION) << "Invalid name: " << name_;
+}
+
+AnfNodePtr Conv2DInfo::GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode) {
+  auto conv2d_attrs = CreateConv2DAttrs();
+  auto node_name = ReplaceNodeName();
+
+  // conv2d
+  if (name_.find(CONV2D_INFO) != std::string::npos) {
+    if (cnode->size() < 3) {
+      MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size();
+    }
+    return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv2d_attrs), new_input, cnode->input(2)});
+  }
+
+  // conv2dtranspose
+  if (cnode->size() < 4) {
+    MS_LOG(EXCEPTION) << name_ << ": The size of cnode is invalid: " << cnode->size();
+  }
+  return gen_g_.PushBack({gen_g_.NewOpInst(node_name, conv2d_attrs), new_input, cnode->input(2), cnode->input(3)});
+}
+
 Status Conv2DInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
   auto graph = cnode->func_graph();
   MS_EXCEPTION_IF_NULL(graph);
-  GenerateGraph gen_g = GenerateGraph(attrs_);
-  if (gen_g.Init(cnode) != SUCCESS) {
-    MS_LOG(ERROR) << "GenerateGraph Init failed";
-    return FAILED;
+  if (gen_g_.Init(cnode) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "GenerateGraph Init failed";
   }
 
+  if (!left_need_send_ && !right_need_send_) {
+    MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to send and right no need to send";
+  }
+
+  if (!left_need_recv_ && !right_need_recv_) {
+    MS_LOG(EXCEPTION) << name_ << ": Now do not support left no need to recv and right no need to recv";
+  }
+
+  std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes;
   std::vector<AnfNodePtr> make_tuple_a_inputs = {NewValueNode(prim::kPrimMakeTuple)};
   if (left_need_send_) {
-    auto slice_left_begin = CreatTuple(left_strided_slice_begin_);
-    auto slice_left_end = CreatTuple(left_strided_slice_end_);
-    auto slice_left_strided = CreatTuple(left_strided_slice_strides_);
-    auto slice_left = gen_g.PushBack(
-      {gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_left_begin, slice_left_end, slice_left_strided});
+    auto slice_left_begin = CreateTuple(left_strided_slice_begin_);
+    auto slice_left_end = CreateTuple(left_strided_slice_end_);
+    auto slice_left_strided = CreateTuple(left_strided_slice_strides_);
+    auto slice_left = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(), slice_left_begin,
+                                       slice_left_end, slice_left_strided});
     make_tuple_a_inputs.push_back(slice_left);
+    input_nodes.push_back(std::make_pair(slice_left, 1));
   }
   if (right_need_send_) {
-    auto slice_right_begin = CreatTuple(right_strided_slice_begin_);
-    auto slice_right_end = CreatTuple(right_strided_slice_end_);
-    auto slice_right_strided = CreatTuple(right_strided_slice_strides_);
-    auto slice_right = gen_g.PushBack(
-      {gen_g.NewOpInst(STRIDED_SLICE), cnode->input(1), slice_right_begin, slice_right_end, slice_right_strided});
+    auto slice_right_begin = CreateTuple(right_strided_slice_begin_);
+    auto slice_right_end = CreateTuple(right_strided_slice_end_);
+    auto slice_right_strided = CreateTuple(right_strided_slice_strides_);
+    auto slice_right = gen_g_.PushBack({gen_g_.NewOpInst(STRIDED_SLICE), gen_g_.virtual_input_node(),
+                                        slice_right_begin, slice_right_end, slice_right_strided});
     make_tuple_a_inputs.push_back(slice_right);
+    input_nodes.push_back(std::make_pair(slice_right, 1));
   }
 
   auto make_tuple_a = graph->NewCNode(make_tuple_a_inputs);
-  auto alltoall_attrs = CreatNeighborExchangeAttrs(cnode);
-  auto alltoall_v = gen_g.PushBack({gen_g.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
-  std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
+  auto alltoall_attrs = CreateNeighborExchangeAttrs(cnode);
+  auto alltoall_v = gen_g_.PushBack({gen_g_.NewOpInst(NEIGHBOREXCHANGE, alltoall_attrs), make_tuple_a});
+
+  AnfNodePtr conv2d;
+  Attr concat_axis = {AXIS, MakeValue(-1)};
+  OperatorAttrs concat_attrs = {concat_axis};
+
   if (left_need_recv_) {
     std::vector<AnfNodePtr> tuple_getitem_l_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
                                                       CreatInt64Imm(0)};
     auto tuple_getitem_l = graph->NewCNode(tuple_getitem_l_inputs);
-    std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), cnode->input(1),
-                                                   tuple_getitem_l};
+    std::vector<AnfNodePtr> make_tuple_l_inputs = {NewValueNode(prim::kPrimMakeTuple), tuple_getitem_l,
                                                  cnode->input(1)};
     auto make_tuple_l = graph->NewCNode(make_tuple_l_inputs);
-    auto concat_l = gen_g.PushBack({gen_g.NewOpInst(CONCAT), make_tuple_l});
-    make_tuple_inputs.push_back(concat_l);
-  }
+    auto concat_l = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_l});
+
+    if (right_need_recv_) {
+      std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
+                                                        CreatInt64Imm(1)};
+      auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
+      std::vector<AnfNodePtr> make_tuple_r_inputs = {NewValueNode(prim::kPrimMakeTuple), concat_l, tuple_getitem_r};
+      auto make_tuple_r = graph->NewCNode(make_tuple_r_inputs);
+      auto concat_r = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r});
+      conv2d = GenerateConv2DNode(concat_r, cnode);
+    } else {
+      conv2d = GenerateConv2DNode(concat_l, cnode);
+    }
+  } else {  // left no need recv, and right need recv
+    std::vector<AnfNodePtr> tuple_getitem_r_inputs_1 = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
+                                                        CreatInt64Imm(0)};
+    auto tuple_getitem_r_1 = graph->NewCNode(tuple_getitem_r_inputs_1);
+    std::vector<AnfNodePtr> make_tuple_r_inputs_1 = {NewValueNode(prim::kPrimMakeTuple), gen_g_.virtual_input_node(),
+                                                     tuple_getitem_r_1};
+    auto make_tuple_r_1 = graph->NewCNode(make_tuple_r_inputs_1);
+    input_nodes.push_back(std::make_pair(make_tuple_r_1, 1));
+
+    auto concat_r_1 = gen_g_.PushBack({gen_g_.NewOpInst(CONCAT, concat_attrs), make_tuple_r_1});
+    conv2d = GenerateConv2DNode(concat_r_1, cnode);
+  }
-  if (right_need_recv_) {
-    std::vector<AnfNodePtr> tuple_getitem_r_inputs = {NewValueNode(prim::kPrimTupleGetItem), alltoall_v,
-                                                      CreatInt64Imm(0)};
-    auto tuple_getitem_r = graph->NewCNode(tuple_getitem_r_inputs);
-    make_tuple_inputs.push_back(tuple_getitem_r);
-  } else {
-    make_tuple_inputs.push_back(cnode->input(1));
-  }
-  auto make_tuple = graph->NewCNode(make_tuple_inputs);
-  Attr concat_axis = {AXIS, MakeValue(-1)};
-  OperatorAttrs concat_attrs = {concat_axis};
-  std::vector<AnfNodePtr> concat_inputs = {gen_g.NewOpInst(CONCAT, concat_attrs), make_tuple};
-  auto concat = graph->NewCNode(concat_inputs);
-  auto conv2d_attrs = CreatConv2DAttrs();
-  auto conv2d = gen_g.PushBack({gen_g.NewOpInst(CONV2D, conv2d_attrs), concat, cnode->input(2)});
 
   replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
     std::make_pair(input_nodes, conv2d));
   return SUCCESS;
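The rewritten ComputeReplaceGraph routes every graph input through gen_g_.virtual_input_node() and records it in input_nodes, instead of reusing cnode->input(1) directly, and wires the pattern StridedSlice -> MakeTuple -> NeighborExchange -> TupleGetItem -> Concat(axis -1) -> Conv2D. A single-process sketch of the halo exchange this subgraph performs at runtime (plain C++; illustrative values stand in for W columns, not MindSpore code):

#include <cstdio>
#include <vector>

// Minimal simulation of the dataflow built above: each "rank" slices its
// border columns, exchanges them with adjacent ranks, concatenates the
// received halos on the W axis, and would then convolve the widened slice.
int main() {
  const int ranks = 4, w = 4, halo = 1;  // W split across 4 ranks, overlap 1
  std::vector<std::vector<int>> slice(ranks, std::vector<int>(w));
  for (int r = 0; r < ranks; ++r)
    for (int i = 0; i < w; ++i) slice[r][i] = r * w + i;  // global column index

  for (int r = 0; r < ranks; ++r) {
    std::vector<int> widened;
    if (r > 0)  // "recv" from the left neighbor: its rightmost columns
      widened.insert(widened.end(), slice[r - 1].end() - halo, slice[r - 1].end());
    widened.insert(widened.end(), slice[r].begin(), slice[r].end());
    if (r + 1 < ranks)  // "recv" from the right neighbor: its leftmost columns
      widened.insert(widened.end(), slice[r + 1].begin(), slice[r + 1].begin() + halo);

    std::printf("rank %d convolves over columns:", r);
    for (int v : widened) std::printf(" %d", v);
    std::printf("\n");
  }
}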
@@ -23,6 +23,7 @@
 #include <vector>
 
 #include "ir/value.h"
+#include "frontend/parallel/graph_util/generate_graph.h"
 #include "frontend/parallel/auto_parallel/operator_costmodel.h"
 #include "frontend/parallel/ops_info/operator_info.h"
 #include "frontend/parallel/strategy.h"
@@ -57,9 +58,11 @@ class Conv2DInfo : public OperatorInfo {
   void InferSendRecvFlag();
   void InferOverlapShapes();
   void InferStridedSliceAttrs();
+  std::string ReplaceNodeName();
+  AnfNodePtr GenerateConv2DNode(const AnfNodePtr &new_input, const CNodePtr &cnode);
   ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
-  OperatorAttrs CreatNeighborExchangeAttrs(const CNodePtr &cnode);
-  OperatorAttrs CreatConv2DAttrs();
+  OperatorAttrs CreateNeighborExchangeAttrs(const CNodePtr &cnode);
+  OperatorAttrs CreateConv2DAttrs();
   Status ComputeReplaceGraph(const CNodePtr &cnode);
 
   int64_t out_channel_ = 1;
@@ -106,6 +109,8 @@ class Conv2DInfo : public OperatorInfo {
   Shapes send_shapes_;
   Shapes recv_shapes_;
 
+  GenerateGraph gen_g_ = GenerateGraph(attrs_);
+
   virtual Status CheckHWStrategy(int64_t h_strategy, int64_t w_strategy);
   virtual void InferNewPadList();
   virtual int64_t ComputeOverlapLeftSizeByRankBias(int64_t rank_bias);
@@ -172,6 +172,22 @@ Status GatherDInfo::InferMirrorOps() {
   return SUCCESS;
 }
 
+void GatherDInfo::ReComputeBatchSplitFlagList() {
+  if (InferAttrs() != SUCCESS) {
+    MS_LOG(EXCEPTION) << name_ << ": Infer attrs failed";
+  }
+
+  if (dim_ == 0) {
+    MS_LOG(EXCEPTION)
+      << name_
+      << ": Can not generate batch data parallel strategy since the dim is 0, please set others strategy for it";
+  }
+
+  for (size_t i = 0; i < inputs_shape_.size(); ++i) {
+    split_flag_list_[i] = true;
+  }
+}
+
 Status GatherDInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { return SetCostUnderStrategyBase(strategy); }
 
 std::vector<StrategyPtr> GatherDInfo::GenerateOpStrategies(int64_t stage_id) {
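The new ReComputeBatchSplitFlagList gives GatherD a default batch data parallel strategy: every input is split on axis 0 unless the gather dim is 0, in which case the split would cut across the indexed axis and the operator raises instead. A compact sketch of that rule (hypothetical helper, not the MindSpore API):

#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

// Batch data parallelism splits axis 0 of every input tensor, so it is
// rejected when GatherD gathers along dim 0.
std::vector<bool> BatchSplitFlags(std::size_t num_inputs, int64_t gather_dim) {
  if (gather_dim == 0) {
    throw std::runtime_error("cannot batch-split GatherD when dim is 0");
  }
  return std::vector<bool>(num_inputs, true);  // split axis 0 of each input
}

int main() {
  auto flags = BatchSplitFlags(2, 2);  // gather dim 2: both inputs batch-split
  std::printf("flags: %d %d\n", (int)flags[0], (int)flags[1]);
}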
@@ -40,6 +40,7 @@ class GatherDInfo : public OperatorInfo {
   Status InitForCostModel(const StrategyPtr &strategy) override;
   std::vector<StrategyPtr> GenerateOpStrategies(int64_t) override;
   Status SetCostUnderStrategy(const StrategyPtr &) override;
+  void ReComputeBatchSplitFlagList() override;
 
  protected:
   Status GetAttrs() override;
@@ -283,6 +283,9 @@ constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
 constexpr char CONV2D[] = "Conv2D";
 constexpr char CONV2D_BACK_PROP_INPUT[] = "Conv2DBackpropInput";
 constexpr char CONV2D_TRANSPOSE[] = "Conv2DTranspose";
+constexpr char CONV2D_INFO[] = "Conv2DInfo";
+constexpr char CONV2D_BACK_PROP_INPUT_INFO[] = "Conv2DBackpropInputInfo";
+constexpr char CONV2D_TRANSPOSE_INFO[] = "Conv2DTransposeInfo";
 constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
 constexpr char FUSE_BATCH_NORM_EX[] = "FusedBatchNormEx";
 constexpr char BATCH_NORM[] = "BatchNorm";
@@ -39,6 +39,8 @@ class Net(Cell):
 
 _x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
 _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
+_w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
+_w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
 _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
 
 
@@ -75,6 +77,31 @@ def test_conv2d_model_parallel2():
     compile_net(net)
 
 
+def test_conv2d_model_parallel3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
+    strategy2 = ((2, 1, 1, 4),)
+    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
+def test_conv2d_model_parallel4():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
+    strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
+    strategy2 = ((2, 2, 1, 4),)
+    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
+def test_conv2d_left_and_right_no_need_to_send():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
+    strategy2 = ((2, 1, 1, 4),)
+    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
+    with pytest.raises(RuntimeError):
+        compile_net(net)
+
+
 def test_conv2d_output_can_not_divisible_by_strategy():
     context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
     strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
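For reference, a shard strategy divides each input axis by the matching strategy element, so strategy1 = ((2, 1, 1, 4), ...) leaves every rank in the tests above with a [16, 16, 8, 2] slice of _x. A small sketch of that arithmetic (hypothetical helper):

#include <array>
#include <cstdint>
#include <cstdio>

// Each axis of the input shape is divided by the corresponding strategy
// element to get the per-rank slice shape (illustrative helper only).
std::array<int64_t, 4> SliceShape(std::array<int64_t, 4> shape, std::array<int64_t, 4> strategy) {
  for (int i = 0; i < 4; ++i) shape[i] /= strategy[i];
  return shape;
}

int main() {
  auto s = SliceShape({32, 16, 8, 8}, {2, 1, 1, 4});  // strategy1 of the tests above
  std::printf("[%lld, %lld, %lld, %lld]\n", (long long)s[0], (long long)s[1], (long long)s[2], (long long)s[3]);
  // prints [16, 16, 8, 2]: N split across 2 ranks, W split across 4
}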
@@ -36,8 +36,24 @@ class Net(Cell):
         return out
 
 
+class Net2(Cell):
+    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
+                 strategy1=None, strategy2=None):
+        super().__init__()
+        self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
+                                                  pad_mode=pad_mode, stride=stride).shard(strategy1)
+        self.neg = P.Neg().shard(strategy2)
+        self.weight = Parameter(conv2d_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.conv2d_transpose(x, self.weight, (32, 16, 16, 16))
+        out = self.neg(out)
+        return out
+
+
 _x = Tensor(np.ones([32, 8, 8, 8]), dtype=ms.float32)
 _w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
+_w2 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
 _b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
 
 
@@ -64,3 +80,21 @@ def test_conv2d_transpose_model_parallel1():
     strategy2 = ((8, 1, 1, 1),)
     net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
     compile_net(net)
+
+
+def test_conv2d_transpose_model_parallel2():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
+    strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
+    strategy2 = ((2, 1, 1, 4),)
+    net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
+               strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
+
+
+def test_conv2d_transpose_model_parallel3():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
+    strategy2 = ((2, 2, 1, 4),)
+    net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
+               strategy1=strategy1, strategy2=strategy2)
+    compile_net(net)
@@ -65,6 +65,14 @@ def test_gathernd_dim2():
     compile_net(net)
 
 
+def test_gathernd_dim2_default_batch_parallel():
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
+    strategy1 = None
+    strategy2 = ((2, 8, 1),)
+    net = Net(2, _w1, strategy1, strategy2)
+    compile_net(net)
+
+
 def test_gathernd_auto_parallel():
     context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
     net = Net(1, _w1)