forked from mindspore-Ecosystem/mindspore
StridedSlice
This commit is contained in:
parent
70152adcb3
commit
0487f755bb
|
@ -30,6 +30,7 @@
|
|||
#include "ops/neg.h"
|
||||
#include "ops/mul.h"
|
||||
#include "ops/sub.h"
|
||||
#include "ops/strided_slice.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "abstract/infer_functions.h"
|
||||
#include "ops/tile.h"
|
||||
|
@ -217,6 +218,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() {
|
|||
{prim::kPrimRealDiv, {ops::RealDivInfer, nullptr, false}},
|
||||
{prim::kPrimShape, {InferImplShape, nullptr, false}},
|
||||
{prim::kPrimTranspose, {InferImplTranspose, nullptr, true}},
|
||||
{prim::kPrimStridedSlice, {ops::StridedSliceInfer, nullptr, true}},
|
||||
{prim::kPrimReshape, {InferImplReshape, nullptr, true}},
|
||||
{prim::kPrimConcat, {InferImplConcat, nullptr, true}},
|
||||
{prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}},
|
||||
|
|
|
@ -74,6 +74,7 @@ constexpr auto kReLUGradV2 = "ReluGradV2";
|
|||
constexpr auto kGeLUGrad = "GeLUGrad";
|
||||
constexpr auto kFastGeLU = "FastGeLU";
|
||||
constexpr auto kFastGeLUGrad = "FastGeLUGrad";
|
||||
constexpr auto kStridedSlice = "StridedSlice";
|
||||
constexpr auto kZerosLike = "ZerosLike";
|
||||
constexpr auto kOnesLike = "OnesLike";
|
||||
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
|
||||
|
@ -165,7 +166,7 @@ inline const PrimitivePtr kPrimGatherND = std::make_shared<Primitive>("GatherND"
|
|||
inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared<Primitive>("SparseGatherV2");
|
||||
inline const PrimitivePtr kPrimSparseToDense = std::make_shared<Primitive>("SparseToDense");
|
||||
inline const PrimitivePtr kPrimShape = std::make_shared<Primitive>("Shape");
|
||||
inline const PrimitivePtr kPrimStridedSlice = std::make_shared<Primitive>("StridedSlice");
|
||||
inline const PrimitivePtr kPrimStridedSlice = std::make_shared<Primitive>(kStridedSlice);
|
||||
inline const PrimitivePtr kPrimDynamicShape = std::make_shared<Primitive>("DynamicShape");
|
||||
inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared<Primitive>("EmbeddingLookup");
|
||||
inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared<Primitive>("EmbeddingLookupCommGrad");
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,7 +28,233 @@
|
|||
namespace mindspore {
|
||||
namespace ops {
|
||||
namespace {
|
||||
std::vector<int64_t> TenToTwo(int64_t num) {
|
||||
void EllipsisInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &x_shape,
|
||||
const std::vector<int64_t> &begin_v, const std::vector<int64_t> &end_v,
|
||||
const std::vector<int64_t> &strides_v, std::vector<int64_t> *infer_shape, size_t i, size_t j,
|
||||
bool has_ellipsis) {
|
||||
if (!has_ellipsis) {
|
||||
return;
|
||||
}
|
||||
auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
|
||||
size_t x_rank = x_shape.size();
|
||||
size_t slice_len = begin_v.size();
|
||||
std::vector<int64_t> begin_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_begin_mask());
|
||||
std::vector<int64_t> end_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_end_mask());
|
||||
std::vector<int64_t> new_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_new_axis_mask());
|
||||
std::vector<int64_t> shrink_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_shrink_axis_mask());
|
||||
|
||||
int64_t num = 0;
|
||||
for (size_t n = j + 1; n < slice_len; n++) {
|
||||
if (new_axis_pos[n] == 1) {
|
||||
num++;
|
||||
}
|
||||
}
|
||||
|
||||
int64_t ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + num;
|
||||
infer_shape->insert(infer_shape->end(), x_shape.begin() + i, x_shape.begin() + i + ellipsis_occupied_dims);
|
||||
j += 1;
|
||||
i += ellipsis_occupied_dims;
|
||||
|
||||
while (i < x_rank || j < slice_len) {
|
||||
int64_t x_dim_size = x_shape[i];
|
||||
int64_t start = begin_v[j];
|
||||
int64_t finish = end_v[j];
|
||||
int64_t strides = strides_v[j];
|
||||
if (j < begin_pos.size() || j < slice_len) {
|
||||
start = strides_v[j] < 0 ? -1 : 0;
|
||||
}
|
||||
if (j < end_pos.size() && end_pos[j] == 1) {
|
||||
finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
|
||||
}
|
||||
if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
|
||||
infer_shape->push_back(1);
|
||||
j += 1;
|
||||
continue;
|
||||
}
|
||||
if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
|
||||
if ((-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
|
||||
MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
|
||||
}
|
||||
j += 1;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
int64_t slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_dim_size);
|
||||
infer_shape->push_back(slicing_length);
|
||||
i += 1;
|
||||
j += 1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const std::vector<int64_t> CheckAndGetValidStrides(const AbstractBasePtr &stride_arg) {
|
||||
MS_EXCEPTION_IF_NULL(stride_arg);
|
||||
auto temp_strides = stride_arg->cast<abstract::AbstractTuplePtr>()->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(temp_strides);
|
||||
auto strides = GetValue<std::vector<int64_t>>(temp_strides);
|
||||
if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) { return stride == 0; })) {
|
||||
MS_EXCEPTION(ValueError) << "StridedSlice's input strides cannot contain 0.";
|
||||
}
|
||||
return strides;
|
||||
}
|
||||
|
||||
std::vector<int64_t> ComputeInferShape(const PrimitivePtr &primitive, const std::vector<int64_t> &begin_v,
|
||||
const std::vector<int64_t> &end_v, const std::vector<int64_t> &x_shape,
|
||||
const std::vector<int64_t> &strides_v, const std::vector<int64_t> &begin_pos,
|
||||
const std::vector<int64_t> &shrink_axis_pos, const std::vector<int64_t> &end_pos,
|
||||
const std::vector<int64_t> &new_axis_pos,
|
||||
const std::vector<int64_t> &ellipsis_pos) {
|
||||
size_t i = 0;
|
||||
size_t j = 0;
|
||||
int64_t start;
|
||||
int64_t finish;
|
||||
int64_t strides;
|
||||
int64_t slicing_length;
|
||||
bool has_ellipsis = false;
|
||||
std::vector<int64_t> infer_shape;
|
||||
size_t slice_len = begin_v.size();
|
||||
size_t x_rank = x_shape.size();
|
||||
while (i < x_rank || j < slice_len) {
|
||||
int64_t x_dim_size = x_shape[i];
|
||||
if (j < slice_len) {
|
||||
start = begin_v[j];
|
||||
finish = end_v[j];
|
||||
strides = strides_v[j];
|
||||
if (j < ellipsis_pos.size() && ellipsis_pos[j] == 1) {
|
||||
has_ellipsis = true;
|
||||
break;
|
||||
}
|
||||
if (j < begin_pos.size() && begin_pos[j] == 1) {
|
||||
start = strides_v[j] < 0 ? -1 : 0;
|
||||
}
|
||||
if (j < end_pos.size() && end_pos[j] == 1) {
|
||||
finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
|
||||
}
|
||||
if (j < new_axis_pos.size() && new_axis_pos[j] == 1) {
|
||||
infer_shape.push_back(1);
|
||||
j += 1;
|
||||
continue;
|
||||
}
|
||||
if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
|
||||
if ((-x_shape[i] <= start && start < x_shape[i]) || strides < 0) {
|
||||
MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
|
||||
}
|
||||
j += 1;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
start = 0;
|
||||
finish = x_shape[0];
|
||||
strides = 1;
|
||||
}
|
||||
auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
|
||||
MS_EXCEPTION_IF_NULL(strided_slice_prim);
|
||||
slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_dim_size);
|
||||
infer_shape.push_back(slicing_length);
|
||||
i += 1;
|
||||
j += 1;
|
||||
}
|
||||
EllipsisInferShape(primitive, x_shape, begin_v, end_v, strides_v, &infer_shape, i, j, has_ellipsis);
|
||||
return infer_shape;
|
||||
}
|
||||
|
||||
abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto strided_slice_prim = primitive->cast<PrimStridedSlicePtr>();
|
||||
MS_EXCEPTION_IF_NULL(strided_slice_prim);
|
||||
auto temp_begin_v = input_args[1]->cast<abstract::AbstractTuplePtr>()->BuildValue();
|
||||
auto begin_v = GetValue<std::vector<int64_t>>(temp_begin_v);
|
||||
auto temp_end_v = input_args[2]->cast<abstract::AbstractTuplePtr>()->BuildValue();
|
||||
auto end_v = GetValue<std::vector<int64_t>>(temp_end_v);
|
||||
auto strides_v = CheckAndGetValidStrides(input_args[3]);
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
|
||||
auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
|
||||
std::vector<int64_t> begin_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_begin_mask());
|
||||
std::vector<int64_t> end_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_end_mask());
|
||||
std::vector<int64_t> ellipsis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_ellipsis_mask());
|
||||
std::vector<int64_t> new_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_new_axis_mask());
|
||||
std::vector<int64_t> shrink_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_shrink_axis_mask());
|
||||
auto ret_in_shape = ComputeInferShape(primitive, begin_v, end_v, x_shape, strides_v, begin_pos, shrink_axis_pos,
|
||||
end_pos, new_axis_pos, ellipsis_pos);
|
||||
if (min_shape.empty() || max_shape.empty()) {
|
||||
return std::make_shared<abstract::Shape>(ret_in_shape);
|
||||
}
|
||||
auto ret_min_shape = ComputeInferShape(primitive, begin_v, end_v, min_shape, strides_v, begin_pos, shrink_axis_pos,
|
||||
end_pos, new_axis_pos, ellipsis_pos);
|
||||
auto ret_max_shape = ComputeInferShape(primitive, begin_v, end_v, max_shape, strides_v, begin_pos, shrink_axis_pos,
|
||||
end_pos, new_axis_pos, ellipsis_pos);
|
||||
return std::make_shared<abstract::Shape>(ret_in_shape, ret_min_shape, ret_max_shape);
|
||||
}
|
||||
|
||||
TypePtr StridedSliceInferType(const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
return infer_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void StridedSlice::set_begin_mask(const int64_t begin_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kBeginMask, MakeValue(begin_mask));
|
||||
}
|
||||
int64_t StridedSlice::get_begin_mask() const {
|
||||
auto value_ptr = GetAttr(kBeginMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_end_mask(const int64_t end_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kEndMask, MakeValue(end_mask));
|
||||
}
|
||||
int64_t StridedSlice::get_end_mask() const {
|
||||
auto value_ptr = GetAttr(kEndMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
|
||||
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
|
||||
std::ostringstream buffer;
|
||||
if (bs.count() > 1) {
|
||||
buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask();
|
||||
MS_EXCEPTION(ValueError) << buffer.str();
|
||||
}
|
||||
this->AddAttr(kEllipsisMask, MakeValue(ellipsis_mask));
|
||||
}
|
||||
int64_t StridedSlice::get_ellipsis_mask() const {
|
||||
auto value_ptr = GetAttr(kEllipsisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
|
||||
}
|
||||
int64_t StridedSlice::get_new_axis_mask() const {
|
||||
auto value_ptr = GetAttr(kNewAxisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
|
||||
}
|
||||
int64_t StridedSlice::get_shrink_axis_mask() const {
|
||||
auto value_ptr = GetAttr(kShrinkAxisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
|
||||
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
|
||||
this->set_begin_mask(begin_mask);
|
||||
this->set_end_mask(end_mask);
|
||||
this->set_ellipsis_mask(ellipsis_mask);
|
||||
this->set_new_axis_mask(new_axis_mask);
|
||||
this->set_shrink_axis_mask(shrink_axis_mask);
|
||||
}
|
||||
|
||||
std::vector<int64_t> StridedSlice::TenToTwo(int64_t num) {
|
||||
std::vector<int64_t> output;
|
||||
if (num == 0) {
|
||||
output.push_back(0);
|
||||
|
@ -42,13 +268,7 @@ std::vector<int64_t> TenToTwo(int64_t num) {
|
|||
return output;
|
||||
}
|
||||
|
||||
int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, std::vector<int64_t> x_shape,
|
||||
int64_t i) {
|
||||
if (i > (int64_t)x_shape.size()) {
|
||||
MS_EXCEPTION(ValueError) << "For 'StridedSlice', When their is no new axis, "
|
||||
"the index length must be less or equal than the dim of x.";
|
||||
}
|
||||
int64_t x_dim = x_shape[i];
|
||||
int64_t StridedSlice::compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) const {
|
||||
int64_t slicing_length = 0;
|
||||
if (strides > 0) {
|
||||
if ((start_pos >= x_dim) || end_pos < -x_dim) {
|
||||
|
@ -97,183 +317,13 @@ int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strid
|
|||
}
|
||||
return slicing_length;
|
||||
}
|
||||
abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
auto temp_begin_v = input_args[1]->cast<abstract::AbstractTuplePtr>()->BuildValue();
|
||||
auto begin_v = GetValue<std::vector<int64_t>>(temp_begin_v);
|
||||
auto temp_end_v = input_args[2]->cast<abstract::AbstractTuplePtr>()->BuildValue();
|
||||
auto end_v = GetValue<std::vector<int64_t>>(temp_end_v);
|
||||
auto temp_strides_v = input_args[3]->cast<abstract::AbstractTuplePtr>()->BuildValue();
|
||||
auto strides_v = GetValue<std::vector<int64_t>>(temp_strides_v);
|
||||
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
int64_t x_rank = x_shape.size();
|
||||
int64_t slice_len = begin_v.size();
|
||||
std::vector<int64_t> begin_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kBeginMask)));
|
||||
std::vector<int64_t> end_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kEndMask)));
|
||||
std::vector<int64_t> ellipsis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kEllipsisMask)));
|
||||
std::vector<int64_t> new_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kNewAxisMask)));
|
||||
std::vector<int64_t> shrink_axis_pos = TenToTwo(GetValue<int64_t>(primitive->GetAttr(kShrinkAxisMask)));
|
||||
|
||||
int64_t i = 0;
|
||||
int64_t j = 0;
|
||||
int64_t start;
|
||||
int64_t finish;
|
||||
int64_t strides;
|
||||
int64_t slicing_length;
|
||||
bool has_ellipsis = false;
|
||||
std::vector<int64_t> infer_shape;
|
||||
while (i < x_rank || j < slice_len) {
|
||||
if (j < slice_len) {
|
||||
start = begin_v[j];
|
||||
finish = end_v[j];
|
||||
strides = strides_v[j];
|
||||
if (j < (int64_t)ellipsis_pos.size() && ellipsis_pos[j] == 1) {
|
||||
has_ellipsis = true;
|
||||
break;
|
||||
}
|
||||
if (j < (int64_t)begin_pos.size() && begin_pos[j] == 1) {
|
||||
start = strides_v[j] < 0 ? -1 : 0;
|
||||
}
|
||||
if (j < (int64_t)end_pos.size() && end_pos[j] == 1) {
|
||||
finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
|
||||
}
|
||||
if (j < (int64_t)new_axis_pos.size() && new_axis_pos[j] == 1) {
|
||||
infer_shape.push_back(1);
|
||||
j += 1;
|
||||
continue;
|
||||
}
|
||||
if (j < (int64_t)shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
|
||||
if (((-x_shape[i] <= start && start < x_shape[i]) == false) || strides < 0) {
|
||||
MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
|
||||
}
|
||||
j += 1;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
start = 0;
|
||||
finish = x_shape[0];
|
||||
strides = 1;
|
||||
}
|
||||
slicing_length = compute_slicing_length(start, finish, strides, x_shape, i);
|
||||
infer_shape.push_back(slicing_length);
|
||||
i += 1;
|
||||
j += 1;
|
||||
}
|
||||
|
||||
int64_t num = 0;
|
||||
for (int64_t n = j + 1; n < slice_len; n++) {
|
||||
if (new_axis_pos[n] == 1) {
|
||||
num++;
|
||||
}
|
||||
}
|
||||
if (has_ellipsis) {
|
||||
int64_t ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + num;
|
||||
infer_shape.insert(infer_shape.end(), x_shape.begin() + i, x_shape.begin() + i + ellipsis_occupied_dims);
|
||||
j += 1;
|
||||
i += ellipsis_occupied_dims;
|
||||
|
||||
while (i < x_rank || j < slice_len) {
|
||||
start = begin_v[j];
|
||||
finish = end_v[j];
|
||||
strides = strides_v[j];
|
||||
if (j < (int64_t)begin_pos.size() || j < slice_len) {
|
||||
start = strides_v[j] < 0 ? -1 : 0;
|
||||
}
|
||||
if (j < (int64_t)end_pos.size() && end_pos[j] == 1) {
|
||||
finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i];
|
||||
}
|
||||
if (j < (int64_t)new_axis_pos.size() && new_axis_pos[j] == 1) {
|
||||
infer_shape.push_back(1);
|
||||
j += 1;
|
||||
continue;
|
||||
}
|
||||
if (j < (int64_t)shrink_axis_pos.size() && shrink_axis_pos[j] == 1) {
|
||||
if (((-x_shape[i] <= start && start < x_shape[i]) == false) || strides < 0) {
|
||||
MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number";
|
||||
}
|
||||
j += 1;
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
slicing_length = compute_slicing_length(start, finish, strides, x_shape, i);
|
||||
infer_shape.push_back(slicing_length);
|
||||
i += 1;
|
||||
j += 1;
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(infer_shape);
|
||||
}
|
||||
|
||||
TypePtr StridedSliceInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto infer_type = input_args[0]->BuildType()->cast<TensorTypePtr>()->element();
|
||||
return infer_type;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void StridedSlice::set_begin_mask(const int64_t begin_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kBeginMask, MakeValue(begin_mask));
|
||||
}
|
||||
int64_t StridedSlice::get_begin_mask() const { return GetValue<int64_t>(GetAttr(kBeginMask)); }
|
||||
void StridedSlice::set_end_mask(const int64_t end_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kEndMask, MakeValue(end_mask));
|
||||
}
|
||||
int64_t StridedSlice::get_end_mask() const {
|
||||
auto value_ptr = GetAttr(kEndMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
|
||||
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
|
||||
std::ostringstream buffer;
|
||||
if (bs.count() > 1) {
|
||||
buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask();
|
||||
MS_EXCEPTION(ValueError) << buffer.str();
|
||||
}
|
||||
this->AddAttr(kEllipsisMask, MakeValue(ellipsis_mask));
|
||||
}
|
||||
int64_t StridedSlice::get_ellipsis_mask() const {
|
||||
auto value_ptr = GetAttr(kEllipsisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
|
||||
}
|
||||
int64_t StridedSlice::get_new_axis_mask() const {
|
||||
auto value_ptr = GetAttr(kNewAxisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
|
||||
CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
|
||||
this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
|
||||
}
|
||||
int64_t StridedSlice::get_shrink_axis_mask() const {
|
||||
auto value_ptr = GetAttr(kShrinkAxisMask);
|
||||
return GetValue<int64_t>(value_ptr);
|
||||
}
|
||||
void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
|
||||
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
|
||||
this->set_begin_mask(begin_mask);
|
||||
this->set_end_mask(end_mask);
|
||||
this->set_ellipsis_mask(ellipsis_mask);
|
||||
this->set_new_axis_mask(new_axis_mask);
|
||||
this->set_shrink_axis_mask(shrink_axis_mask);
|
||||
}
|
||||
|
||||
AbstractBasePtr StridedSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(StridedSliceInferType(primitive, input_args),
|
||||
return std::make_shared<abstract::AbstractTensor>(StridedSliceInferType(input_args),
|
||||
StridedSliceInferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameStridedSlice, StridedSlice);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -26,10 +26,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameStridedSlice = "StridedSlice";
|
||||
constexpr auto kNameStridedSlice = prim::kStridedSlice;
|
||||
class StridedSlice : public PrimitiveC {
|
||||
public:
|
||||
StridedSlice() : PrimitiveC(kNameStridedSlice) { InitIOName({"x", "begin", "end", "strides"}, {"output"}); }
|
||||
StridedSlice() : PrimitiveC(prim::kPrimStridedSlice->name()) {
|
||||
InitIOName({"x", "begin", "end", "strides"}, {"output"});
|
||||
}
|
||||
~StridedSlice() = default;
|
||||
MS_DECLARE_PARENT(StridedSlice, PrimitiveC);
|
||||
void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0,
|
||||
|
@ -45,8 +47,10 @@ class StridedSlice : public PrimitiveC {
|
|||
int64_t get_new_axis_mask() const;
|
||||
int64_t get_shrink_axis_mask() const;
|
||||
std::vector<int64_t> TenToTwo(int64_t num);
|
||||
int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, std::vector<int64_t> x_shape,
|
||||
int64_t i);
|
||||
int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) const;
|
||||
};
|
||||
struct ComputeHasEllipsis {
|
||||
bool has_ellipsis;
|
||||
};
|
||||
AbstractBasePtr StridedSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
|
|
Loading…
Reference in New Issue