add parallel op for resize bilinear

yangzhenzhang 2021-11-30 11:58:49 +08:00
parent 7ebfbb0278
commit 846db9206f
7 changed files with 386 additions and 29 deletions

@@ -31,20 +31,6 @@
namespace mindspore {
namespace parallel {
namespace {
ValuePtr MakeListValue(const std::vector<int64_t> &v) {
std::vector<ValuePtr> list;
(void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); });
return std::make_shared<ValueSequeue>(list);
}
ValuePtr MakeTupleListValue(const Shapes &v) {
std::vector<ValuePtr> tuple;
(void)std::transform(v.begin(), v.end(), std::back_inserter(tuple),
[](const std::vector<int64_t> &list) { return MakeListValue(list); });
return std::make_shared<ValueTuple>(tuple);
}
} // namespace
Status Conv2DInfo::GetAttrsBase() {
// format
format_ = GetStringAttr(FORMAT);
@@ -653,8 +639,11 @@ OperatorAttrs Conv2DInfo::CreateNeighborExchangeAttrs(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(tensor_type);
auto dtype = tensor_type->element();
MS_EXCEPTION_IF_NULL(dtype);
Attr send_ranks = {SEND_RNAK_IDS, MakeListValue(send_rank_ids_)};
Attr recv_ranks = {RECV_RNAK_IDS, MakeListValue(recv_rank_ids_)};
// the type of send_rank_ids, recv_rank_ids, send_shapes and recv_shapes is list, not tuple, so MakeValue can not
// be used here: MakeValue(vector) returns a tuple
Attr send_ranks = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)};
Attr recv_ranks = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)};
Attr send_shapes = {SEND_SHAPES, MakeTupleListValue(send_shapes_)};
Attr recv_shapes = {RECV_SHAPES, MakeTupleListValue(recv_shapes_)};
Attr recv_type = {RECV_TYPE, dtype};

@@ -1890,5 +1890,24 @@ std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue) {
auto val = sequeue->cast<ValueListPtr>();
return val->value();
}
ValuePtr MakeListValue(const std::vector<int64_t> &v) {
std::vector<ValuePtr> list;
(void)std::transform(v.begin(), v.end(), std::back_inserter(list), [](int64_t ele) { return MakeValue(ele); });
return std::make_shared<ValueSequeue>(list);
}
ValuePtr MakeTupleListValue(const Shapes &v) {
std::vector<ValuePtr> tuple;
(void)std::transform(v.begin(), v.end(), std::back_inserter(tuple),
[](const std::vector<int64_t> &list) { return MakeListValue(list); });
return std::make_shared<ValueTuple>(tuple);
}
AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple) {
auto value_ptr = MakeValue(value_tuple)->cast<ValueTuplePtr>();
auto value_node = NewValueNode(value_ptr);
return value_node->cast<AnfNodePtr>();
}
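// A minimal usage sketch of the three helpers above (hypothetical values;
// illustration only, not part of this commit):
//   std::vector<int64_t> ids = {-1, -1, 2, -1, -1, -1, 0, -1};
//   ValuePtr as_list = MakeListValue(ids);    // a ValueSequeue, i.e. a list
//   ValuePtr as_tuple = MakeValue(ids);       // a ValueTuple -- unusable where a list attr is required
//   AnfNodePtr size_node = CreateValueTupleAnfNodePtr({224, 224});  // ValueNode holding a tuple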
} // namespace parallel
} // namespace mindspore

@@ -343,6 +343,9 @@ Status GenerateStrategiesWithBroadcast(int64_t stage_id, const Shapes &inputs_sh
Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph);
std::vector<ValuePtr> GetValueSequeue(const ValuePtr &sequeue);
ValuePtr MakeListValue(const std::vector<int64_t> &v);
ValuePtr MakeTupleListValue(const Shapes &v);
AnfNodePtr CreateValueTupleAnfNodePtr(const std::vector<int64_t> &value_tuple);
} // namespace parallel
} // namespace mindspore

@@ -172,11 +172,17 @@ constexpr char REPLACE[] = "replace";
constexpr char CONNSYMBOL[] = "/";
constexpr char INSTANCE_NAME[] = "instance_name";
constexpr char SPLIT_SENS[] = "split_sens";
constexpr char SEND_RNAK_IDS[] = "send_rank_ids";
constexpr char RECV_RNAK_IDS[] = "recv_rank_ids";
constexpr char SEND_RANK_IDS[] = "send_rank_ids";
constexpr char RECV_RANK_IDS[] = "recv_rank_ids";
constexpr char RECV_SHAPES[] = "recv_shapes";
constexpr char SEND_SHAPES[] = "send_shapes";
constexpr char RECV_TYPE[] = "recv_type";
constexpr char SEND_LENS[] = "send_lens";
constexpr char RECV_LENS[] = "recv_lens";
constexpr char ORI_IMAGE_SIZE[] = "ori_image_size";
constexpr char SPLIT_SIZE[] = "split_size";
constexpr char SRC_START_W[] = "src_start_w";
constexpr char DST_START_W[] = "dst_start_w";
constexpr char SPLIT_TENSOR[] = "split_tensor";
constexpr char DEV_MAT[] = "dev_mat";
constexpr char TENSOR_MAP[] = "tensor_map";
@@ -238,6 +244,7 @@ constexpr char SPLIT[] = "Split";
constexpr char ALL_TO_ALL[] = "AlltoAll";
constexpr char NEIGHBOREXCHANGE[] = "NeighborExchange";
constexpr char NEIGHBOREXCHANGEV2[] = "NeighborExchangeV2";
constexpr char PARALLEL_RESIZE_BILINEAR[] = "ParallelResizeBilinear";
constexpr char PERMUTE_BY_AXIS[] = "PermuteByAxis";
constexpr char CONCAT_BY_AXIS[] = "ConcatByAxis";
constexpr char SPLIT_BY_AXIS[] = "SplitByAxis";

@@ -20,6 +20,7 @@
#include <memory>
#include <utility>
#include <vector>
#include <cmath>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
@@ -28,6 +29,8 @@
namespace mindspore {
namespace parallel {
// ResizeBilinear: supports splitting N/C/W
// ResizeNearestNeighbor: supports splitting N/C/H/W if align_corners=False, and only N/C if align_corners=True
Status ResizeBilinearInfo::GetAttrs() {
size_ = GetTupleIntAttr(SIZE);
if (size_.size() != 2) {
@@ -43,8 +46,9 @@ Status ResizeBilinearInfo::GetAttrs() {
Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
need_exchange_overlap_ = false;
if (CheckStrategyValue(strategy, inputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
MS_LOG(ERROR) << name_ << ": Check input strategy failed";
return FAILED;
}
@@ -60,8 +64,19 @@ Status ResizeBilinearInfo::CheckStrategy(const StrategyPtr &strategy) {
return FAILED;
}
if (input_strategy[2] != 1 || input_strategy[3] != 1) {
MS_LOG(ERROR) << name_ << ": Do not support split from H or W";
if (input_strategy[2] != 1) {
MS_LOG(ERROR) << name_ << ": Do not support split H dimension";
return FAILED;
}
if (input_strategy[3] != 1) {
need_exchange_overlap_ = true;
MS_LOG(INFO) << name_ << ": Split the w dimension";
}
// check output strategy
if (CheckStrategyValue(strategy, outputs_shape_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Check output strategy failed";
return FAILED;
}
@@ -87,6 +102,7 @@ Status ResizeBilinearInfo::InferDevMatrixShape() {
slice_size_ = size_;
slice_size_[0] = slice_size_[0] / dev_matrix_shape_[2];
slice_size_[1] = slice_size_[1] / dev_matrix_shape_[3];
w_dimension_shard_num_ = dev_matrix_shape_[3];
return SUCCESS;
}
@@ -119,10 +135,282 @@ std::vector<StrategyPtr> ResizeBilinearInfo::GenerateOpStrategies(int64_t stage_
}
void ResizeBilinearInfo::ReplaceNodeInputOrAttrs() {
// if the overlap needs to be exchanged, replace_graph() is used instead
if (need_exchange_overlap_) {
return;
}
auto prim = GetValueNode<PrimitivePtr>(cnode_->input(0));
prim->set_attr(SIZE, MakeValue(slice_size_));
}
Status ResizeBilinearInfo::InferRankBias() {
// the origin dev_matrix is [n, c, h, w]
// if there is repeated calculation:
// 1) the repeated num is on the left of the dev matrix: the dev_matrix is [repeated_num, n, c, h, w]
// 2) the repeated num is on the right of the dev matrix: the dev_matrix is [n, c, h, w, repeated_num]
uint64_t w_index_in_dev_matrix = 3;
if (repeated_calc_num_ > 1 && !repeated_num_in_dev_matrix_right_) {
w_index_in_dev_matrix += 1;
}
CheckGlobalDeviceManager();
int64_t rank = g_device_manager->global_rank();
DeviceMatrix dev_matrix(rank, stage_device_list_, dev_matrix_shape_);
RankList group_devices;
if (dev_matrix.GetDevicesAlongDim(w_index_in_dev_matrix, &group_devices) != SUCCESS) {
return FAILED;
}
if (group_devices.size() <= 1) {
MS_LOG(INFO) << name_ << ": The devices' size of w dimension is " << group_devices.size()
<< ", no need to infer rank bias";
return SUCCESS;
}
if (group_devices.size() != LongToSize(w_dimension_shard_num_)) {
MS_LOG(ERROR) << name_ << ": The devices' size of w dimension is " << group_devices.size()
<< ", but the shard num of w dimension is " << w_dimension_shard_num_;
return FAILED;
}
std::vector<int64_t>::iterator it = std::find(group_devices.begin(), group_devices.end(), rank);
if (it == group_devices.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the current rank in device list of w dimension, the current rank is "
<< rank << ", the device list is " << group_devices;
return FAILED;
}
rank_bias_ = std::distance(group_devices.begin(), it);
if (it == group_devices.begin()) {
// the current rank is on the left boundary
left_rank_bias_ = -1;
right_rank_bias_ = rank_bias_ + 1;
left_rank_id_ = -1;
right_rank_id_ = *(it + 1);
} else if (it == group_devices.end() - 1) {
// the current rank is on the right boundary
left_rank_bias_ = rank_bias_ - 1;
right_rank_bias_ = -1;
left_rank_id_ = *(it - 1);
right_rank_id_ = -1;
} else {
// the current rank is a middle rank
left_rank_bias_ = rank_bias_ - 1;
right_rank_bias_ = rank_bias_ + 1;
left_rank_id_ = *(it - 1);
right_rank_id_ = *(it + 1);
}
Group group = g_device_manager->CreateGroup(group_devices);
all_to_all_group_ = group.name();
MS_LOG(INFO) << name_ << ": The current rank is " << rank << ", the device list of w dimension is " << group_devices
<< ", the rank bias is " << rank_bias_ << ", the left rank bias is " << left_rank_bias_
<< ", the right rank bias is " << right_rank_bias_ << ", the left rank id is " << left_rank_id_
<< ", the right rank id is " << right_rank_id_ << ", the all to all group is " << all_to_all_group_;
return SUCCESS;
}
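// The boundary handling above, as a standalone sketch with a hypothetical device
// list (the real code obtains it from DeviceMatrix); not part of this commit.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>
int main() {
  const std::vector<int64_t> group_devices = {0, 1, 2, 3};  // ranks along the w dimension
  const int64_t rank = 2;
  auto it = std::find(group_devices.begin(), group_devices.end(), rank);
  int64_t rank_bias = std::distance(group_devices.begin(), it);
  // boundary ranks have no neighbor on one side, marked with -1
  int64_t left_rank_id = (it == group_devices.begin()) ? -1 : *(it - 1);
  int64_t right_rank_id = (it == group_devices.end() - 1) ? -1 : *(it + 1);
  std::cout << rank_bias << " " << left_rank_id << " " << right_rank_id << std::endl;  // prints: 2 1 3
  return 0;
}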
void ResizeBilinearInfo::InferScale() {
origin_in_w_shape_ = inputs_shape_[0][3];
origin_out_w_shape_ = outputs_shape_[0][3];
if (origin_out_w_shape_ == 1) {
MS_LOG(EXCEPTION) << name_ << ": Do not support that the w dimension of output shape is 1";
}
if (align_corners_) {
w_scale_ = LongToDouble(origin_in_w_shape_ - 1) / LongToDouble(origin_out_w_shape_ - 1);
} else {
w_scale_ = LongToDouble(origin_in_w_shape_) / LongToDouble(origin_out_w_shape_);
}
MS_LOG(INFO) << name_ << ": The scale is " << w_scale_;
}
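// A standalone worked example of the scale formula above, with assumed shapes
// (w resized from 6 to 12); a sketch, not part of this commit.
#include <cstdint>
#include <iostream>
int main() {
  const int64_t in_w = 6, out_w = 12;
  // align_corners=false: scale = in_w / out_w
  const double scale = static_cast<double>(in_w) / out_w;  // 0.5
  // align_corners=true: scale = (in_w - 1) / (out_w - 1)
  const double scale_ac = static_cast<double>(in_w - 1) / (out_w - 1);  // 5/11, about 0.4545
  std::cout << scale << " " << scale_ac << std::endl;
  return 0;
}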
int64_t ResizeBilinearInfo::InferOverlapLeftSizeByRankBias(int64_t rank_bias) {
// left_overlap_size = (rank * ori_in_w / w_shard) - floor(scale * rank * slice_w)
int64_t map_left_boundary = DoubleToLong(std::floor(w_scale_ * rank_bias * slice_size_[1]));
int64_t local_left_boundary = rank_bias * origin_in_w_shape_ / w_dimension_shard_num_;
if (map_left_boundary > local_left_boundary) {
MS_LOG(EXCEPTION) << name_ << ": Invalid left overlap, the rank bias is " << rank_bias << ", the map boundary is "
<< map_left_boundary << ", the local boundary is " << local_left_boundary;
}
return local_left_boundary - map_left_boundary;
}
int64_t ResizeBilinearInfo::InferOverlapRightSizeByRankBias(int64_t rank_bias) {
// right_overlap_size = ceil(scale * ((rank + 1) * slice_w - 1)) - ((rank + 1) * ori_in_w / w_shard - 1)
int64_t map_right_boundary = DoubleToLong(std::ceil(w_scale_ * ((rank_bias + 1) * slice_size_[1] - 1)));
int64_t local_right_boundary = (rank_bias + 1) * origin_in_w_shape_ / w_dimension_shard_num_ - 1;
// the mapped boundary can not exceed the last index of the origin input, so clamp it
if (map_right_boundary > origin_in_w_shape_ - 1) {
map_right_boundary = origin_in_w_shape_ - 1;
}
if (map_right_boundary < local_right_boundary) {
MS_LOG(EXCEPTION) << name_ << ": Invalid right overlap, the rank bias is " << rank_bias << ", the map boundary is "
<< map_right_boundary << ", the local boundary is " << local_right_boundary;
}
return map_right_boundary - local_right_boundary;
}
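// A standalone sketch of the two overlap formulas above, under assumed shapes
// (input w=6 resized to w=12, w dimension split into 3 shards, align_corners=false,
// so scale=0.5 and slice_w=4); illustration only, not part of this commit.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
namespace {
int64_t LeftOverlap(double scale, int64_t rank, int64_t slice_w, int64_t in_w, int64_t shard) {
  int64_t map_left = static_cast<int64_t>(std::floor(scale * rank * slice_w));
  int64_t local_left = rank * in_w / shard;
  return local_left - map_left;
}
int64_t RightOverlap(double scale, int64_t rank, int64_t slice_w, int64_t in_w, int64_t shard) {
  int64_t map_right = static_cast<int64_t>(std::ceil(scale * ((rank + 1) * slice_w - 1)));
  map_right = std::min(map_right, in_w - 1);  // clamp to the last valid input index
  int64_t local_right = (rank + 1) * in_w / shard - 1;
  return map_right - local_right;
}
}  // namespace
int main() {
  const int64_t in_w = 6, out_w = 12, shard = 3, slice_w = out_w / shard;
  const double scale = static_cast<double>(in_w) / out_w;  // 0.5
  for (int64_t rank = 0; rank < shard; ++rank) {
    std::cout << "rank " << rank << ": left overlap " << LeftOverlap(scale, rank, slice_w, in_w, shard)
              << ", right overlap " << RightOverlap(scale, rank, slice_w, in_w, shard) << std::endl;
  }
  // prints: rank 0 -> left 0, right 1; rank 1 -> left 0, right 1; rank 2 -> left 0, right 0,
  // i.e. every rank except the last needs one extra input column from its right neighbor.
  return 0;
}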
void ResizeBilinearInfo::InferOverlapSize() {
overlap_left_size_ = InferOverlapLeftSizeByRankBias(rank_bias_);
overlap_right_size_ = InferOverlapRightSizeByRankBias(rank_bias_);
if (rank_bias_ == 0) {
// it has no left rank
left_rank_overlap_right_size_ = 0;
right_rank_overlap_left_size_ = InferOverlapLeftSizeByRankBias(right_rank_bias_);
} else if (rank_bias_ == w_dimension_shard_num_ - 1) {
// it has no right rank
left_rank_overlap_right_size_ = InferOverlapRightSizeByRankBias(left_rank_bias_);
right_rank_overlap_left_size_ = 0;
} else {
// it has left rank and right rank
left_rank_overlap_right_size_ = InferOverlapRightSizeByRankBias(left_rank_bias_);
right_rank_overlap_left_size_ = InferOverlapLeftSizeByRankBias(right_rank_bias_);
}
MS_LOG(INFO) << name_ << ": the left overlap size of current rank is " << overlap_left_size_
<< ", the right overlap size of current rank is " << overlap_right_size_
<< ", the right overlap size of left rank is " << left_rank_overlap_right_size_
<< ", the left overlap size of right rank is " << right_rank_overlap_left_size_;
}
void ResizeBilinearInfo::InferCommunicationAttrs() {
// send rank ids: [-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1]
// recv rank ids: [-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1]
// send lens: [0, 0, send_left_len, send_right_len]
// recv lens: [0, 0, recv_left_len, recv_right_len]
int64_t send_right_rank = -1, send_left_rank = -1, recv_right_rank = -1, recv_left_rank = -1;
int64_t send_left_len = 0, send_right_len = 0, recv_left_len = 0, recv_right_len = 0;
if (rank_bias_ == 0) {
// the first rank
send_right_len = right_rank_overlap_left_size_;
send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
recv_right_len = overlap_right_size_;
recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
} else if (rank_bias_ == w_dimension_shard_num_ - 1) {
// the last rank
send_left_len = left_rank_overlap_right_size_;
send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
recv_left_len = overlap_left_size_;
recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
} else {
// the middle rank
send_right_len = right_rank_overlap_left_size_;
send_right_rank = send_right_len > 0 ? right_rank_id_ : -1;
recv_right_len = overlap_right_size_;
recv_right_rank = recv_right_len > 0 ? right_rank_id_ : -1;
send_left_len = left_rank_overlap_right_size_;
send_left_rank = send_left_len > 0 ? left_rank_id_ : -1;
recv_left_len = overlap_left_size_;
recv_left_rank = recv_left_len > 0 ? left_rank_id_ : -1;
}
send_rank_ids_ = {-1, -1, send_right_rank, -1, -1, -1, send_left_rank, -1};
recv_rank_ids_ = {-1, -1, recv_right_rank, -1, -1, -1, recv_left_rank, -1};
send_lens_ = {0, 0, send_left_len, send_right_len};
recv_lens_ = {0, 0, recv_left_len, recv_right_len};
MS_LOG(INFO) << name_ << ": The send rank ids is " << send_rank_ids_ << ", the send lens is " << send_lens_
<< ", the recv rank ids is " << recv_rank_ids_ << ", the recv lens is " << recv_lens_;
}
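// Worked example (assumed 3-way w shard from the overlap sketch above: right
// overlaps {1, 1, 0}, left overlaps {0, 0, 0}, ranks 0/1/2; not part of this commit):
//   rank 0: recv_rank_ids[2]=1, recv_lens={0,0,0,1}; sends nothing
//   rank 1: send_rank_ids[6]=0, send_lens={0,0,1,0}; recv_rank_ids[2]=2, recv_lens={0,0,0,1}
//   rank 2: send_rank_ids[6]=1, send_lens={0,0,1,0}; receives nothing
// Each send matches the neighbor's recv: rank 1 sends the column that rank 0
// receives, and rank 2 sends the column that rank 1 receives.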
void ResizeBilinearInfo::InferResizeBilinearV2Attrs() {
origin_image_size_ = {inputs_shape_[0][2], inputs_shape_[0][3]};
src_start_w_ = DoubleToLong(std::floor(w_scale_ * rank_bias_ * slice_size_[1]));
dst_start_w_ = rank_bias_ * slice_size_[1];
MS_LOG(INFO) << name_ << ": The origin image size is " << origin_image_size_ << ", src start index is "
<< src_start_w_ << ", dst start index is " << dst_start_w_;
}
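// Continuing the same sketch (scale=0.5, slice_w=4, assumed values): rank 1
// starts its output slice at dst_start_w = 1 * 4 = 4, which maps back to
// src_start_w = floor(0.5 * 1 * 4) = 2 in the origin input, while
// origin_image_size_ stays the full [H, W] of the unsplit input.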
void ResizeBilinearInfo::InferNewOperatorAttrs() {
InferCommunicationAttrs();
InferResizeBilinearV2Attrs();
}
OperatorAttrs ResizeBilinearInfo::CreateNeighborExchangeV2Attrs() {
// the type of send_rank_ids, recv_rank_ids, send_lens and recv_lens is list, not tuple, so MakeValue can not
// be used here: MakeValue(vector) returns a tuple
Attr send_rank_ids = {SEND_RANK_IDS, MakeListValue(send_rank_ids_)};
Attr send_lens = {SEND_LENS, MakeListValue(send_lens_)};
Attr recv_rank_ids = {RECV_RANK_IDS, MakeListValue(recv_rank_ids_)};
Attr recv_lens = {RECV_LENS, MakeListValue(recv_lens_)};
Attr data_format = {DATA_FORMAT, MakeValue(NCHW)};
Attr group = {GROUP, MakeValue(all_to_all_group_)};
OperatorAttrs attrs = {send_rank_ids, send_lens, recv_rank_ids, recv_lens, data_format, group};
return attrs;
}
OperatorAttrs ResizeBilinearInfo::CreateParallelResizeBilinearAttrs() {
Attr ori_image_size = {ORI_IMAGE_SIZE, MakeValue(origin_image_size_)};
Attr split_size = {SPLIT_SIZE, MakeValue(slice_size_)};
Attr src_start_w = {SRC_START_W, MakeValue(src_start_w_)};
Attr dst_start_w = {DST_START_W, MakeValue(dst_start_w_)};
Attr align_corners = {ALIGN_CORNERS, MakeValue(align_corners_)};
OperatorAttrs attrs = {ori_image_size, split_size, src_start_w, dst_start_w, align_corners};
return attrs;
}
void ResizeBilinearInfo::InferReplaceGraph(const CNodePtr &cnode) {
auto graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(graph);
GenerateGraph gen_g = GenerateGraph(attrs_);
if (gen_g.Init(cnode) != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": Init generator graph failed";
}
auto neighbor_exchange_v2_attrs = CreateNeighborExchangeV2Attrs();
auto neighbor_exchange_v2_node =
gen_g.PushBack({gen_g.NewOpInst(NEIGHBOREXCHANGEV2, neighbor_exchange_v2_attrs), gen_g.virtual_input_node()});
auto size = CreateValueTupleAnfNodePtr(size_);
auto parallel_resize_bilinear_attrs = CreateParallelResizeBilinearAttrs();
auto parallel_resize_bilinear_node = gen_g.PushBack(
{gen_g.NewOpInst(PARALLEL_RESIZE_BILINEAR, parallel_resize_bilinear_attrs), neighbor_exchange_v2_node, size});
std::vector<std::pair<AnfNodePtr, int64_t>> input_nodes = {std::make_pair(neighbor_exchange_v2_node, 1)};
replace_graph_ = std::make_shared<std::pair<std::vector<std::pair<AnfNodePtr, int64_t>>, AnfNodePtr>>(
std::make_pair(input_nodes, parallel_resize_bilinear_node));
}
ReplaceGraphPtr ResizeBilinearInfo::replace_graph(const CNodePtr &cnode) {
if (!need_exchange_overlap_) {
return nullptr;
}
if (InferRankBias() != SUCCESS) {
MS_LOG(EXCEPTION) << name_ << ": infer rank bias failed";
}
InferScale();
InferOverlapSize();
InferNewOperatorAttrs();
InferReplaceGraph(cnode);
return replace_graph_;
}
Status ResizeNearestNeighborInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);

@@ -23,6 +23,7 @@
#include "utils/hash_map.h"
#include "ir/value.h"
#include "frontend/parallel/graph_util/generate_graph.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
@@ -46,10 +47,66 @@ class ResizeBilinearInfo : public OperatorInfo {
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
void ReplaceNodeInputOrAttrs() override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
std::vector<int64_t> size_;
std::vector<int64_t> slice_size_;
bool align_corners_ = false;
bool need_exchange_overlap_ = false;
private:
Status InferRankBias();
void InferOverlapSize();
void InferScale();
void InferNewOperatorAttrs();
void InferCommunicationAttrs();
void InferResizeBilinearV2Attrs();
void InferReplaceGraph(const CNodePtr &cnode);
int64_t InferOverlapLeftSizeByRankBias(int64_t rank_bias);
int64_t InferOverlapRightSizeByRankBias(int64_t rank_bias);
OperatorAttrs CreateNeighborExchangeV2Attrs();
OperatorAttrs CreateParallelResizeBilinearAttrs();
// rank_bias_ is the position of the current rank in the w dimension of the dev_matrix (the h dimension is not split)
int64_t rank_bias_ = 0;
int64_t left_rank_bias_ = -1;
int64_t right_rank_bias_ = -1;
int64_t left_rank_id_ = -1;
int64_t right_rank_id_ = -1;
int64_t overlap_left_size_ = 0;
int64_t overlap_right_size_ = 0;
int64_t left_rank_overlap_right_size_ = 0;
int64_t right_rank_overlap_left_size_ = 0;
int64_t origin_in_w_shape_ = 1;
int64_t origin_out_w_shape_ = 1;
int64_t w_dimension_shard_num_ = 1;
// send_rank_ids_ and recv_rank_ids_ are arrays of 8 rank ids; the indexes in the arrays are organized in
// the following layout ('R' is the current rank)
// +++++++++++++
// | 7 | 0 | 1 |
// +++++++++++++
// | 6 | R | 2 |
// +++++++++++++
// | 5 | 4 | 3 |
// +++++++++++++
std::vector<int64_t> send_rank_ids_; // 8 rank ids
std::vector<int64_t> recv_rank_ids_; // 8 rank ids
// send_lens_ and recv_lens_ are arrays of 4 lengths, ordered as top, bottom, left, right
std::vector<int64_t> send_lens_; // [top, bottom, left, right]
std::vector<int64_t> recv_lens_; // [top, bottom, left, right]
std::string all_to_all_group_;
std::vector<int64_t> origin_image_size_; // [H, W]
int64_t src_start_w_ = 0;
int64_t dst_start_w_ = 0;
double w_scale_ = 1.0; // the scale in w dimension, now only support to split w dimension
};
class ResizeNearestNeighborInfo : public ResizeBilinearInfo {

@@ -184,12 +184,6 @@ ReplaceGraphPtr SliceInfo::replace_graph(const CNodePtr &cnode) {
return replace_graph_;
}
AnfNodePtr CreateValueTupleAndNodePtr(const std::vector<int64_t> &value_tuple) {
auto value_ptr = MakeValue(value_tuple)->cast<ValueTuplePtr>();
auto value_node = NewValueNode(value_ptr);
return value_node->cast<AnfNodePtr>();
}
Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
GenerateGraph gen_g = GenerateGraph(attrs_);
if (gen_g.Init(cnode) != SUCCESS) {
@@ -207,8 +201,8 @@ Status SliceInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
sliced_size_shape_int.push_back(input_slice_shape[i]);
}
}
auto new_begin = CreateValueTupleAndNodePtr(begin_);
auto new_size = CreateValueTupleAndNodePtr(sliced_size_shape_int);
auto new_begin = CreateValueTupleAnfNodePtr(begin_);
auto new_size = CreateValueTupleAnfNodePtr(sliced_size_shape_int);
auto slice = gen_g.PushBack({gen_g.NewOpInst(SLICE), gen_g.virtual_input_node(), new_begin, new_size});