From 0487f755bb624e51781da8d2090e01df139f0ab9 Mon Sep 17 00:00:00 2001 From: shen_jingxing Date: Fri, 18 Jun 2021 19:05:44 +0800 Subject: [PATCH] StridedSlice --- .../core/abstract/primitive_infer_map.cc | 2 + mindspore/core/base/core_ops.h | 3 +- mindspore/core/ops/strided_slice.cc | 412 ++++++++++-------- mindspore/core/ops/strided_slice.h | 14 +- 4 files changed, 244 insertions(+), 187 deletions(-) diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 3616372dd30..0986c497c7d 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -30,6 +30,7 @@ #include "ops/neg.h" #include "ops/mul.h" #include "ops/sub.h" +#include "ops/strided_slice.h" #include "abstract/abstract_function.h" #include "abstract/infer_functions.h" #include "ops/tile.h" @@ -217,6 +218,7 @@ PrimitiveEvalImplMap &GetPrimitiveToBackendEvalImplMap() { {prim::kPrimRealDiv, {ops::RealDivInfer, nullptr, false}}, {prim::kPrimShape, {InferImplShape, nullptr, false}}, {prim::kPrimTranspose, {InferImplTranspose, nullptr, true}}, + {prim::kPrimStridedSlice, {ops::StridedSliceInfer, nullptr, true}}, {prim::kPrimReshape, {InferImplReshape, nullptr, true}}, {prim::kPrimConcat, {InferImplConcat, nullptr, true}}, {prim::kPrimArgMaxWithValue, {InferImplArgMaxWithValue, nullptr, true}}, diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index a414e30f50f..b1ad86808d6 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -74,6 +74,7 @@ constexpr auto kReLUGradV2 = "ReluGradV2"; constexpr auto kGeLUGrad = "GeLUGrad"; constexpr auto kFastGeLU = "FastGeLU"; constexpr auto kFastGeLUGrad = "FastGeLUGrad"; +constexpr auto kStridedSlice = "StridedSlice"; constexpr auto kZerosLike = "ZerosLike"; constexpr auto kOnesLike = "OnesLike"; constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs"; @@ -165,7 +166,7 @@ inline const PrimitivePtr kPrimGatherND = std::make_shared("GatherND" inline const PrimitivePtr kPrimSparseGatherV2 = std::make_shared("SparseGatherV2"); inline const PrimitivePtr kPrimSparseToDense = std::make_shared("SparseToDense"); inline const PrimitivePtr kPrimShape = std::make_shared("Shape"); -inline const PrimitivePtr kPrimStridedSlice = std::make_shared("StridedSlice"); +inline const PrimitivePtr kPrimStridedSlice = std::make_shared(kStridedSlice); inline const PrimitivePtr kPrimDynamicShape = std::make_shared("DynamicShape"); inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared("EmbeddingLookup"); inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("EmbeddingLookupCommGrad"); diff --git a/mindspore/core/ops/strided_slice.cc b/mindspore/core/ops/strided_slice.cc index 510b8decc3e..33360004737 100644 --- a/mindspore/core/ops/strided_slice.cc +++ b/mindspore/core/ops/strided_slice.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,233 @@ namespace mindspore { namespace ops { namespace { -std::vector TenToTwo(int64_t num) { +void EllipsisInferShape(const PrimitivePtr &primitive, const std::vector &x_shape, + const std::vector &begin_v, const std::vector &end_v, + const std::vector &strides_v, std::vector *infer_shape, size_t i, size_t j, + bool has_ellipsis) { + if (!has_ellipsis) { + return; + } + auto strided_slice_prim = primitive->cast(); + size_t x_rank = x_shape.size(); + size_t slice_len = begin_v.size(); + std::vector begin_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_begin_mask()); + std::vector end_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_end_mask()); + std::vector new_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_new_axis_mask()); + std::vector shrink_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_shrink_axis_mask()); + + int64_t num = 0; + for (size_t n = j + 1; n < slice_len; n++) { + if (new_axis_pos[n] == 1) { + num++; + } + } + + int64_t ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + num; + infer_shape->insert(infer_shape->end(), x_shape.begin() + i, x_shape.begin() + i + ellipsis_occupied_dims); + j += 1; + i += ellipsis_occupied_dims; + + while (i < x_rank || j < slice_len) { + int64_t x_dim_size = x_shape[i]; + int64_t start = begin_v[j]; + int64_t finish = end_v[j]; + int64_t strides = strides_v[j]; + if (j < begin_pos.size() || j < slice_len) { + start = strides_v[j] < 0 ? -1 : 0; + } + if (j < end_pos.size() && end_pos[j] == 1) { + finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i]; + } + if (j < new_axis_pos.size() && new_axis_pos[j] == 1) { + infer_shape->push_back(1); + j += 1; + continue; + } + if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) { + if ((-x_shape[i] <= start && start < x_shape[i]) || strides < 0) { + MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number"; + } + j += 1; + i += 1; + continue; + } + int64_t slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_dim_size); + infer_shape->push_back(slicing_length); + i += 1; + j += 1; + } + return; +} + +const std::vector CheckAndGetValidStrides(const AbstractBasePtr &stride_arg) { + MS_EXCEPTION_IF_NULL(stride_arg); + auto temp_strides = stride_arg->cast()->BuildValue(); + MS_EXCEPTION_IF_NULL(temp_strides); + auto strides = GetValue>(temp_strides); + if (std::any_of(strides.begin(), strides.end(), [](int64_t stride) { return stride == 0; })) { + MS_EXCEPTION(ValueError) << "StridedSlice's input strides cannot contain 0."; + } + return strides; +} + +std::vector ComputeInferShape(const PrimitivePtr &primitive, const std::vector &begin_v, + const std::vector &end_v, const std::vector &x_shape, + const std::vector &strides_v, const std::vector &begin_pos, + const std::vector &shrink_axis_pos, const std::vector &end_pos, + const std::vector &new_axis_pos, + const std::vector &ellipsis_pos) { + size_t i = 0; + size_t j = 0; + int64_t start; + int64_t finish; + int64_t strides; + int64_t slicing_length; + bool has_ellipsis = false; + std::vector infer_shape; + size_t slice_len = begin_v.size(); + size_t x_rank = x_shape.size(); + while (i < x_rank || j < slice_len) { + int64_t x_dim_size = x_shape[i]; + if (j < slice_len) { + start = begin_v[j]; + finish = end_v[j]; + strides = strides_v[j]; + if (j < ellipsis_pos.size() && ellipsis_pos[j] == 1) { + has_ellipsis = true; + break; + } + if (j < begin_pos.size() && begin_pos[j] == 1) { + start = strides_v[j] < 0 ? -1 : 0; + } + if (j < end_pos.size() && end_pos[j] == 1) { + finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i]; + } + if (j < new_axis_pos.size() && new_axis_pos[j] == 1) { + infer_shape.push_back(1); + j += 1; + continue; + } + if (j < shrink_axis_pos.size() && shrink_axis_pos[j] == 1) { + if ((-x_shape[i] <= start && start < x_shape[i]) || strides < 0) { + MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number"; + } + j += 1; + i += 1; + continue; + } + } else { + start = 0; + finish = x_shape[0]; + strides = 1; + } + auto strided_slice_prim = primitive->cast(); + MS_EXCEPTION_IF_NULL(strided_slice_prim); + slicing_length = strided_slice_prim->compute_slicing_length(start, finish, strides, x_dim_size); + infer_shape.push_back(slicing_length); + i += 1; + j += 1; + } + EllipsisInferShape(primitive, x_shape, begin_v, end_v, strides_v, &infer_shape, i, j, has_ellipsis); + return infer_shape; +} + +abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive, + const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto strided_slice_prim = primitive->cast(); + MS_EXCEPTION_IF_NULL(strided_slice_prim); + auto temp_begin_v = input_args[1]->cast()->BuildValue(); + auto begin_v = GetValue>(temp_begin_v); + auto temp_end_v = input_args[2]->cast()->BuildValue(); + auto end_v = GetValue>(temp_end_v); + auto strides_v = CheckAndGetValidStrides(input_args[3]); + + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + auto min_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape]; + auto max_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape]; + std::vector begin_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_begin_mask()); + std::vector end_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_end_mask()); + std::vector ellipsis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_ellipsis_mask()); + std::vector new_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_new_axis_mask()); + std::vector shrink_axis_pos = strided_slice_prim->TenToTwo(strided_slice_prim->get_shrink_axis_mask()); + auto ret_in_shape = ComputeInferShape(primitive, begin_v, end_v, x_shape, strides_v, begin_pos, shrink_axis_pos, + end_pos, new_axis_pos, ellipsis_pos); + if (min_shape.empty() || max_shape.empty()) { + return std::make_shared(ret_in_shape); + } + auto ret_min_shape = ComputeInferShape(primitive, begin_v, end_v, min_shape, strides_v, begin_pos, shrink_axis_pos, + end_pos, new_axis_pos, ellipsis_pos); + auto ret_max_shape = ComputeInferShape(primitive, begin_v, end_v, max_shape, strides_v, begin_pos, shrink_axis_pos, + end_pos, new_axis_pos, ellipsis_pos); + return std::make_shared(ret_in_shape, ret_min_shape, ret_max_shape); +} + +TypePtr StridedSliceInferType(const std::vector &input_args) { + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto infer_type = input_args[0]->BuildType()->cast()->element(); + return infer_type; +} +} // namespace + +void StridedSlice::set_begin_mask(const int64_t begin_mask) { + CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name()); + this->AddAttr(kBeginMask, MakeValue(begin_mask)); +} +int64_t StridedSlice::get_begin_mask() const { + auto value_ptr = GetAttr(kBeginMask); + return GetValue(value_ptr); +} +void StridedSlice::set_end_mask(const int64_t end_mask) { + CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name()); + this->AddAttr(kEndMask, MakeValue(end_mask)); +} +int64_t StridedSlice::get_end_mask() const { + auto value_ptr = GetAttr(kEndMask); + return GetValue(value_ptr); +} +void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) { + CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name()); + std::bitset bs(ellipsis_mask); + std::ostringstream buffer; + if (bs.count() > 1) { + buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask(); + MS_EXCEPTION(ValueError) << buffer.str(); + } + this->AddAttr(kEllipsisMask, MakeValue(ellipsis_mask)); +} +int64_t StridedSlice::get_ellipsis_mask() const { + auto value_ptr = GetAttr(kEllipsisMask); + return GetValue(value_ptr); +} +void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) { + CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name()); + this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask)); +} +int64_t StridedSlice::get_new_axis_mask() const { + auto value_ptr = GetAttr(kNewAxisMask); + return GetValue(value_ptr); +} +void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) { + CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name()); + this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask)); +} +int64_t StridedSlice::get_shrink_axis_mask() const { + auto value_ptr = GetAttr(kShrinkAxisMask); + return GetValue(value_ptr); +} +void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask, + const int64_t new_axis_mask, const int64_t shrink_axis_mask) { + this->set_begin_mask(begin_mask); + this->set_end_mask(end_mask); + this->set_ellipsis_mask(ellipsis_mask); + this->set_new_axis_mask(new_axis_mask); + this->set_shrink_axis_mask(shrink_axis_mask); +} + +std::vector StridedSlice::TenToTwo(int64_t num) { std::vector output; if (num == 0) { output.push_back(0); @@ -42,13 +268,7 @@ std::vector TenToTwo(int64_t num) { return output; } -int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, std::vector x_shape, - int64_t i) { - if (i > (int64_t)x_shape.size()) { - MS_EXCEPTION(ValueError) << "For 'StridedSlice', When their is no new axis, " - "the index length must be less or equal than the dim of x."; - } - int64_t x_dim = x_shape[i]; +int64_t StridedSlice::compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) const { int64_t slicing_length = 0; if (strides > 0) { if ((start_pos >= x_dim) || end_pos < -x_dim) { @@ -97,183 +317,13 @@ int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strid } return slicing_length; } -abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto prim_name = primitive->name(); - auto temp_begin_v = input_args[1]->cast()->BuildValue(); - auto begin_v = GetValue>(temp_begin_v); - auto temp_end_v = input_args[2]->cast()->BuildValue(); - auto end_v = GetValue>(temp_end_v); - auto temp_strides_v = input_args[3]->cast()->BuildValue(); - auto strides_v = GetValue>(temp_strides_v); - - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - int64_t x_rank = x_shape.size(); - int64_t slice_len = begin_v.size(); - std::vector begin_pos = TenToTwo(GetValue(primitive->GetAttr(kBeginMask))); - std::vector end_pos = TenToTwo(GetValue(primitive->GetAttr(kEndMask))); - std::vector ellipsis_pos = TenToTwo(GetValue(primitive->GetAttr(kEllipsisMask))); - std::vector new_axis_pos = TenToTwo(GetValue(primitive->GetAttr(kNewAxisMask))); - std::vector shrink_axis_pos = TenToTwo(GetValue(primitive->GetAttr(kShrinkAxisMask))); - - int64_t i = 0; - int64_t j = 0; - int64_t start; - int64_t finish; - int64_t strides; - int64_t slicing_length; - bool has_ellipsis = false; - std::vector infer_shape; - while (i < x_rank || j < slice_len) { - if (j < slice_len) { - start = begin_v[j]; - finish = end_v[j]; - strides = strides_v[j]; - if (j < (int64_t)ellipsis_pos.size() && ellipsis_pos[j] == 1) { - has_ellipsis = true; - break; - } - if (j < (int64_t)begin_pos.size() && begin_pos[j] == 1) { - start = strides_v[j] < 0 ? -1 : 0; - } - if (j < (int64_t)end_pos.size() && end_pos[j] == 1) { - finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i]; - } - if (j < (int64_t)new_axis_pos.size() && new_axis_pos[j] == 1) { - infer_shape.push_back(1); - j += 1; - continue; - } - if (j < (int64_t)shrink_axis_pos.size() && shrink_axis_pos[j] == 1) { - if (((-x_shape[i] <= start && start < x_shape[i]) == false) || strides < 0) { - MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number"; - } - j += 1; - i += 1; - continue; - } - } else { - start = 0; - finish = x_shape[0]; - strides = 1; - } - slicing_length = compute_slicing_length(start, finish, strides, x_shape, i); - infer_shape.push_back(slicing_length); - i += 1; - j += 1; - } - - int64_t num = 0; - for (int64_t n = j + 1; n < slice_len; n++) { - if (new_axis_pos[n] == 1) { - num++; - } - } - if (has_ellipsis) { - int64_t ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + num; - infer_shape.insert(infer_shape.end(), x_shape.begin() + i, x_shape.begin() + i + ellipsis_occupied_dims); - j += 1; - i += ellipsis_occupied_dims; - - while (i < x_rank || j < slice_len) { - start = begin_v[j]; - finish = end_v[j]; - strides = strides_v[j]; - if (j < (int64_t)begin_pos.size() || j < slice_len) { - start = strides_v[j] < 0 ? -1 : 0; - } - if (j < (int64_t)end_pos.size() && end_pos[j] == 1) { - finish = strides_v[j] < 0 ? -(x_shape[i] + 1) : x_shape[i]; - } - if (j < (int64_t)new_axis_pos.size() && new_axis_pos[j] == 1) { - infer_shape.push_back(1); - j += 1; - continue; - } - if (j < (int64_t)shrink_axis_pos.size() && shrink_axis_pos[j] == 1) { - if (((-x_shape[i] <= start && start < x_shape[i]) == false) || strides < 0) { - MS_EXCEPTION(ValueError) << "when shrink axis, the stride cannot be negative number"; - } - j += 1; - i += 1; - continue; - } - slicing_length = compute_slicing_length(start, finish, strides, x_shape, i); - infer_shape.push_back(slicing_length); - i += 1; - j += 1; - } - } - return std::make_shared(infer_shape); -} - -TypePtr StridedSliceInferType(const PrimitivePtr &prim, const std::vector &input_args) { - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); - } - auto infer_type = input_args[0]->BuildType()->cast()->element(); - return infer_type; -} -} // namespace - -void StridedSlice::set_begin_mask(const int64_t begin_mask) { - CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name()); - this->AddAttr(kBeginMask, MakeValue(begin_mask)); -} -int64_t StridedSlice::get_begin_mask() const { return GetValue(GetAttr(kBeginMask)); } -void StridedSlice::set_end_mask(const int64_t end_mask) { - CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name()); - this->AddAttr(kEndMask, MakeValue(end_mask)); -} -int64_t StridedSlice::get_end_mask() const { - auto value_ptr = GetAttr(kEndMask); - return GetValue(value_ptr); -} -void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) { - CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name()); - std::bitset bs(ellipsis_mask); - std::ostringstream buffer; - if (bs.count() > 1) { - buffer << "For" << this->name() << ", only support one ellipsis in the index, but got " << this->get_end_mask(); - MS_EXCEPTION(ValueError) << buffer.str(); - } - this->AddAttr(kEllipsisMask, MakeValue(ellipsis_mask)); -} -int64_t StridedSlice::get_ellipsis_mask() const { - auto value_ptr = GetAttr(kEllipsisMask); - return GetValue(value_ptr); -} -void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) { - CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name()); - this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask)); -} -int64_t StridedSlice::get_new_axis_mask() const { - auto value_ptr = GetAttr(kNewAxisMask); - return GetValue(value_ptr); -} -void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) { - CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name()); - this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask)); -} -int64_t StridedSlice::get_shrink_axis_mask() const { - auto value_ptr = GetAttr(kShrinkAxisMask); - return GetValue(value_ptr); -} -void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask, - const int64_t new_axis_mask, const int64_t shrink_axis_mask) { - this->set_begin_mask(begin_mask); - this->set_end_mask(end_mask); - this->set_ellipsis_mask(ellipsis_mask); - this->set_new_axis_mask(new_axis_mask); - this->set_shrink_axis_mask(shrink_axis_mask); -} AbstractBasePtr StridedSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args) { - return std::make_shared(StridedSliceInferType(primitive, input_args), + return std::make_shared(StridedSliceInferType(input_args), StridedSliceInferShape(primitive, input_args)->shape()); } REGISTER_PRIMITIVE_C(kNameStridedSlice, StridedSlice); + } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/strided_slice.h b/mindspore/core/ops/strided_slice.h index 8f6739713e3..dcbb0ba66a1 100644 --- a/mindspore/core/ops/strided_slice.h +++ b/mindspore/core/ops/strided_slice.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,10 +26,12 @@ namespace mindspore { namespace ops { -constexpr auto kNameStridedSlice = "StridedSlice"; +constexpr auto kNameStridedSlice = prim::kStridedSlice; class StridedSlice : public PrimitiveC { public: - StridedSlice() : PrimitiveC(kNameStridedSlice) { InitIOName({"x", "begin", "end", "strides"}, {"output"}); } + StridedSlice() : PrimitiveC(prim::kPrimStridedSlice->name()) { + InitIOName({"x", "begin", "end", "strides"}, {"output"}); + } ~StridedSlice() = default; MS_DECLARE_PARENT(StridedSlice, PrimitiveC); void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0, @@ -45,8 +47,10 @@ class StridedSlice : public PrimitiveC { int64_t get_new_axis_mask() const; int64_t get_shrink_axis_mask() const; std::vector TenToTwo(int64_t num); - int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, std::vector x_shape, - int64_t i); + int64_t compute_slicing_length(int64_t start_pos, int64_t end_pos, int64_t strides, int64_t x_dim) const; +}; +struct ComputeHasEllipsis { + bool has_ellipsis; }; AbstractBasePtr StridedSliceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, const std::vector &input_args);