StridedSlice

This commit is contained in:
shen_jingxing 2021-06-18 19:05:44 +08:00
parent 70152adcb3
commit 0487f755bb
4 changed files with 244 additions and 187 deletions

View File

@ -30,6 +30,7 @@
#include "ops/neg.h"
#include "ops/mul.h"
#include "ops/sub.h"
#include "ops/strided_slice.h"
#include "abstract/abstract_function.h"
#include "abstract/infer_functions.h"
#include "ops/tile.h"
@ -217,6 +218,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
{prim::kPrimRealDiv, {ops::RealDivInfer, nullptr, false}},
{prim::kPrimShape, {InferImplShape, nullptr, false}},
{prim::kPrimTranspose, {InferImplTranspose, nullptr, true}},
{prim::kPrimStridedSlice, {ops::StridedSliceInfer, nullptr, true}},
{prim::kPrimReshape, {InferImplReshape, nullptr, true}},
{prim::kPrimConcat, {InferImplConcat, nullptr, true}},
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}},

View File

@ -74,6 +74,7 @@ constexpr auto kReLUGradV2 = "ReluGradV2";
constexpr auto kGeLUGrad = "GeLUGrad";
constexpr auto kFastGeLU = "FastGeLU";
constexpr auto kFastGeLUGrad = "FastGeLUGrad";
constexpr auto kStridedSlice = "StridedSlice";
constexpr auto kZerosLike = "ZerosLike";
constexpr auto kOnesLike = "OnesLike";
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
@ -165,7 +166,7 @@ inline const PrimitivePtr kPrimGatherND = std::make_shared<Primitive>("GatherND"
inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
inline const PrimitivePtr kPrimSparseToDense = std::make_shared<Primitive>("SparseToDense");
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
inline const PrimitivePtr kPrimStridedSlice = std::make_shared<Primitive>("StridedSlice");
inline const PrimitivePtr kPrimStridedSlice = std::make_shared<Primitive>(kStridedSlice);
inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");
inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,7 +28,233 @@
namespace mindspore {
namespace ops {
namespace {
std::vector<int64_t> TenToTwo(int64_t num) {
void EllipsisInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &begin_v, const std::vector<int64_t> &end_v,
const std::vector<int64_t> &strides_v, std::vector<int64_t> *infer_shape, size_t i, size_t j,
bool has_ellipsis) {
if (!has_ellipsis) {
return;
}
auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
size_t x_rank = x_shape.size();
size_t slice_len = begin_v.size();
std::vector<int64_t> begin_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_begin_mask());
std::vector<int64_t> end_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_end_mask());
std::vector<int64_t> new_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_new_axis_mask());
std::vector<int64_t> shrink_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_shrink_axis_mask());
int64_t num = 0;
for (size_t n = j + 1; n < slice_len; n++) {
if (new_axis_pos[n] == 1) {
num++;
}
}
int64_t ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + num;
infer_shape->insert(infer_shape->end(), x_shape.begin() + i, x_shape.begin() + i + ellipsis_occupied_dims);
j += 1;
i += ellipsis_occupied_dims;
while (i < x_rank || j < slice_len) {
int64_t x_dim_size = x_shape[i];
int64_t start = begin_v[j];
int64_t finish = end_v[j];
int64_t strides = strides_v[j];
if (j < begin_pos.size() || j < slice_len) {
start = strides_v[j] < 0 ? -1 : 0;
}
if (j < end_pos.size() && end_pos[j] == 1) {
finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
}
if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
infer_shape->push_back(1);
j += 1;
continue;
}
if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
if ((-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
}
j += 1;
i += 1;
continue;
}
int64_t slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_dim_size);
infer_shape->push_back(slicing_length);
i += 1;
j += 1;
}
return;
}
const std::vector<int64_t> CheckAndGetValidStrides(const AbstractBasePtr &stride_arg) {
MS_EXCEPTION_IF_NULL(stride_arg);
auto temp_strides = stride_arg->cast<abstract::AbstractTuplePtr>()->BuildValue();
MS_EXCEPTION_IF_NULL(temp_strides);
auto strides = GetValue<std::vector<int64_t>>(temp_strides);
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) { return stride == 0; })) {
MS_EXCEPTION(ValueError) << "StridedSlice's input strides cannot contain 0.";
}
return strides;
}
std::vector<int64_t> ComputeInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &begin_v,
const std::vector<int64_t> &end_v, const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &strides_v, const std::vector<int64_t> &begin_pos,
const std::vector<int64_t> &shrink_axis_pos, const std::vector<int64_t> &end_pos,
const std::vector<int64_t> &new_axis_pos,
const std::vector<int64_t> &ellipsis_pos) {
size_t i = 0;
size_t j = 0;
int64_t start;
int64_t finish;
int64_t strides;
int64_t slicing_length;
bool has_ellipsis = false;
std::vector<int64_t> infer_shape;
size_t slice_len = begin_v.size();
size_t x_rank = x_shape.size();
while (i < x_rank || j < slice_len) {
int64_t x_dim_size = x_shape[i];
if (j < slice_len) {
start = begin_v[j];
finish = end_v[j];
strides = strides_v[j];
if (j < ellipsis_pos.size() && ellipsis_pos[j] == 1) {
has_ellipsis = true;
break;
}
if (j < begin_pos.size() && begin_pos[j] == 1) {
start = strides_v[j] < 0 ? -1 : 0;
}
if (j < end_pos.size() && end_pos[j] == 1) {
finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
}
if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
infer_shape.push_back(1);
j += 1;
continue;
}
if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
if ((-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
}
j += 1;
i += 1;
continue;
}
} else {
start = 0;
finish = x_shape[0];
strides = 1;
}
auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
MS_EXCEPTION_IF_NULL(strided_slice_prim);
slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_dim_size);
infer_shape.push_back(slicing_length);
i += 1;
j += 1;
}
EllipsisInferShape(primitive, x_shape, begin_v, end_v, strides_v, &infer_shape, i, j, has_ellipsis);
return infer_shape;
}
abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
MS_EXCEPTION_IF_NULL(strided_slice_prim);
auto temp_begin_v = input_args[1]->cast<abstract::AbstractTuplePtr>()->BuildValue();
auto begin_v = GetValue<std::vector<int64_t>>(temp_begin_v);
auto temp_end_v = input_args[2]->cast<abstract::AbstractTuplePtr>()->BuildValue();
auto end_v = GetValue<std::vector<int64_t>>(temp_end_v);
auto strides_v = CheckAndGetValidStrides(input_args[3]);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
std::vector<int64_t> begin_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_begin_mask());
std::vector<int64_t> end_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_end_mask());
std::vector<int64_t> ellipsis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_ellipsis_mask());
std::vector<int64_t> new_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_new_axis_mask());
std::vector<int64_t> shrink_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_shrink_axis_mask());
auto ret_in_shape = ComputeInferShape(primitive, begin_v, end_v, x_shape, strides_v, begin_pos, shrink_axis_pos,
end_pos, new_axis_pos, ellipsis_pos);
if (min_shape.empty() || max_shape.empty()) {
return std::make_shared<abstract::Shape>(ret_in_shape);
}
auto ret_min_shape = ComputeInferShape(primitive, begin_v, end_v, min_shape, strides_v, begin_pos, shrink_axis_pos,
end_pos, new_axis_pos, ellipsis_pos);
auto ret_max_shape = ComputeInferShape(primitive, begin_v, end_v, max_shape, strides_v, begin_pos, shrink_axis_pos,
end_pos, new_axis_pos, ellipsis_pos);
return std::make_shared<abstract::Shape>(ret_in_shape, ret_min_shape, ret_max_shape);
}
TypePtr StridedSliceInferType(const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
return infer_type;
}
} // namespace
void StridedSlice::set_begin_mask(const int64_t begin_mask) {
CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kBeginMask, MakeValue(begin_mask));
}
int64_t StridedSlice::get_begin_mask() const {
auto value_ptr = GetAttr(kBeginMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_end_mask(const int64_t end_mask) {
CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kEndMask, MakeValue(end_mask));
}
int64_t StridedSlice::get_end_mask() const {
auto value_ptr = GetAttr(kEndMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) {
CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
std::ostringstream buffer;
if (bs.count() > 1) {
buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask();
MS_EXCEPTION(ValueError) << buffer.str();
}
this->AddAttr(kEllipsisMask, MakeValue(ellipsis_mask));
}
int64_t StridedSlice::get_ellipsis_mask() const {
auto value_ptr = GetAttr(kEllipsisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) {
CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
}
int64_t StridedSlice::get_new_axis_mask() const {
auto value_ptr = GetAttr(kNewAxisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
}
int64_t StridedSlice::get_shrink_axis_mask() const {
auto value_ptr = GetAttr(kShrinkAxisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
this->set_begin_mask(begin_mask);
this->set_end_mask(end_mask);
this->set_ellipsis_mask(ellipsis_mask);
this->set_new_axis_mask(new_axis_mask);
this->set_shrink_axis_mask(shrink_axis_mask);
}
std::vector<int64_t> StridedSlice::TenToTwo(int64_t num) {
std::vector<int64_t> output;
if (num == 0) {
output.push_back(0);
@ -42,13 +268,7 @@ std::vector<int64_t> TenToTwo(int64_t num) {
return output;
}
int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, std::vector<int64_t> x_shape,
int64_t i) {
if (i > (int64_t)x_shape.size()) {
MS_EXCEPTION(ValueError) << "For 'StridedSlice', When their is no new axis, "
"the index length must be less or equal than the dim of x.";
}
int64_t x_dim = x_shape[i];
int64_t StridedSlice::compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) const {
int64_t slicing_length = 0;
if (strides > 0) {
if ((start_pos >= x_dim) || end_pos < -x_dim) {
@ -97,183 +317,13 @@ int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strid
}
return slicing_length;
}
abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
auto temp_begin_v = input_args[1]->cast<abstract::AbstractTuplePtr>()->BuildValue();
auto begin_v = GetValue<std::vector<int64_t>>(temp_begin_v);
auto temp_end_v = input_args[2]->cast<abstract::AbstractTuplePtr>()->BuildValue();
auto end_v = GetValue<std::vector<int64_t>>(temp_end_v);
auto temp_strides_v = input_args[3]->cast<abstract::AbstractTuplePtr>()->BuildValue();
auto strides_v = GetValue<std::vector<int64_t>>(temp_strides_v);
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
int64_t x_rank = x_shape.size();
int64_t slice_len = begin_v.size();
std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask)));
std::vector<int64_t> end_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kEndMask)));
std::vector<int64_t> ellipsis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kEllipsisMask)));
std::vector<int64_t> new_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kNewAxisMask)));
std::vector<int64_t> shrink_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kShrinkAxisMask)));
int64_t i = 0;
int64_t j = 0;
int64_t start;
int64_t finish;
int64_t strides;
int64_t slicing_length;
bool has_ellipsis = false;
std::vector<int64_t> infer_shape;
while (i < x_rank || j < slice_len) {
if (j < slice_len) {
start = begin_v[j];
finish = end_v[j];
strides = strides_v[j];
if (j < (int64_t)ellipsis_pos.size() && ellipsis_pos[j] == 1) {
has_ellipsis = true;
break;
}
if (j < (int64_t)begin_pos.size() && begin_pos[j] == 1) {
start = strides_v[j] < 0 ? -1 : 0;
}
if (j < (int64_t)end_pos.size() && end_pos[j] == 1) {
finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
}
if (j < (int64_t)new_axis_pos.size() && new_axis_pos[j] == 1) {
infer_shape.push_back(1);
j += 1;
continue;
}
if (j < (int64_t)shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
if (((-x_shape[i] <= start && start < x_shape[i]) == false) || strides < 0) {
MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
}
j += 1;
i += 1;
continue;
}
} else {
start = 0;
finish = x_shape[0];
strides = 1;
}
slicing_length = compute_slicing_length(start, finish, strides, x_shape, i);
infer_shape.push_back(slicing_length);
i += 1;
j += 1;
}
int64_t num = 0;
for (int64_t n = j + 1; n < slice_len; n++) {
if (new_axis_pos[n] == 1) {
num++;
}
}
if (has_ellipsis) {
int64_t ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + num;
infer_shape.insert(infer_shape.end(), x_shape.begin() + i, x_shape.begin() + i + ellipsis_occupied_dims);
j += 1;
i += ellipsis_occupied_dims;
while (i < x_rank || j < slice_len) {
start = begin_v[j];
finish = end_v[j];
strides = strides_v[j];
if (j < (int64_t)begin_pos.size() || j < slice_len) {
start = strides_v[j] < 0 ? -1 : 0;
}
if (j < (int64_t)end_pos.size() && end_pos[j] == 1) {
finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
}
if (j < (int64_t)new_axis_pos.size() && new_axis_pos[j] == 1) {
infer_shape.push_back(1);
j += 1;
continue;
}
if (j < (int64_t)shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
if (((-x_shape[i] <= start && start < x_shape[i]) == false) || strides < 0) {
MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
}
j += 1;
i += 1;
continue;
}
slicing_length = compute_slicing_length(start, finish, strides, x_shape, i);
infer_shape.push_back(slicing_length);
i += 1;
j += 1;
}
}
return std::make_shared<abstract::Shape>(infer_shape);
}
TypePtr StridedSliceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
return infer_type;
}
} // namespace
void StridedSlice::set_begin_mask(const int64_t begin_mask) {
CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kBeginMask, MakeValue(begin_mask));
}
int64_t StridedSlice::get_begin_mask() const { return GetValue<int64_t>(GetAttr(kBeginMask)); }
void StridedSlice::set_end_mask(const int64_t end_mask) {
CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kEndMask, MakeValue(end_mask));
}
int64_t StridedSlice::get_end_mask() const {
auto value_ptr = GetAttr(kEndMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) {
CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
std::ostringstream buffer;
if (bs.count() > 1) {
buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask();
MS_EXCEPTION(ValueError) << buffer.str();
}
this->AddAttr(kEllipsisMask, MakeValue(ellipsis_mask));
}
int64_t StridedSlice::get_ellipsis_mask() const {
auto value_ptr = GetAttr(kEllipsisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) {
CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
}
int64_t StridedSlice::get_new_axis_mask() const {
auto value_ptr = GetAttr(kNewAxisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
}
int64_t StridedSlice::get_shrink_axis_mask() const {
auto value_ptr = GetAttr(kShrinkAxisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
this->set_begin_mask(begin_mask);
this->set_end_mask(end_mask);
this->set_ellipsis_mask(ellipsis_mask);
this->set_new_axis_mask(new_axis_mask);
this->set_shrink_axis_mask(shrink_axis_mask);
}
AbstractBasePtr StridedSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return std::make_shared<abstract::AbstractTensor>(StridedSliceInferType(primitive, input_args),
return std::make_shared<abstract::AbstractTensor>(StridedSliceInferType(input_args),
StridedSliceInferShape(primitive, input_args)->shape());
}
REGISTER_PRIMITIVE_C(kNameStridedSlice, StridedSlice);
} // namespace ops
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -26,10 +26,12 @@
namespace mindspore {
namespace ops {
constexpr auto kNameStridedSlice = "StridedSlice";
constexpr auto kNameStridedSlice = prim::kStridedSlice;
class StridedSlice : public PrimitiveC {
public:
StridedSlice() : PrimitiveC(kNameStridedSlice) { InitIOName({"x", "begin", "end", "strides"}, {"output"}); }
StridedSlice() : PrimitiveC(prim::kPrimStridedSlice->name()) {
InitIOName({"x", "begin", "end", "strides"}, {"output"});
}
~StridedSlice() = default;
MS_DECLARE_PARENT(StridedSlice, PrimitiveC);
void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0,
@ -45,8 +47,10 @@ class StridedSlice : public PrimitiveC {
int64_t get_new_axis_mask() const;
int64_t get_shrink_axis_mask() const;
std::vector<int64_t> TenToTwo(int64_t num);
int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, std::vector<int64_t> x_shape,
int64_t i);
int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) const;
};
struct ComputeHasEllipsis {
bool has_ellipsis;
};
AbstractBasePtr StridedSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);