From 63d3d8b6863b0a414db10f6beb7cea5846129088 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Tue, 1 Dec 2020 14:38:30 +0800 Subject: [PATCH] fix strided slice multi inputs bug --- .../runtime/kernel/arm/base/strided_slice.cc | 45 ------------------- .../runtime/kernel/arm/base/strided_slice.h | 3 -- .../tflite_inputs_order_exchange_pass.cc | 1 - 3 files changed, 49 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc index 6fc8d8b9c81..a1d19b3a978 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.cc @@ -29,12 +29,6 @@ using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_StridedSlice; namespace mindspore::kernel { -namespace { -constexpr size_t kMultiInputsSize = 4; -constexpr size_t kBeginsIndex = 1; -constexpr size_t kEndsIndex = 2; -constexpr size_t kStridesInex = 3; -} // namespace int StridedSliceCPUKernel::Init() { if (!InferShapeDone()) { return RET_OK; @@ -57,38 +51,6 @@ int StridedSliceCPUKernel::ReSize() { return RET_OK; } -int StridedSliceCPUKernel::HandleMultiInputs() { - if (in_tensors_.size() != kMultiInputsSize) { - MS_LOG(ERROR) << "Inputs size should be " << kMultiInputsSize << ", got " << in_tensors_.size(); - return RET_ERROR; - } - if (param_ == nullptr) { - MS_LOG(ERROR) << "StridedSliceParamater cast nullptr"; - return RET_ERROR; - } - auto begins = in_tensors_.at(kBeginsIndex); - MS_ASSERT(begins != nullptr); - int axis_num = begins->ElementsNum(); - if (axis_num > DIMENSION_6D) { - MS_LOG(ERROR) << "StridedSlice supports max dimension " << DIMENSION_6D << ", input begins dim is " << axis_num; - return RET_ERROR; - } - memcpy(param_->begins_, begins->MutableData(), axis_num * sizeof(int)); - - auto ends = in_tensors_.at(kEndsIndex); - MS_ASSERT(ends != nullptr); - MS_ASSERT(axis_num == ends->ElementsNum()); - memcpy(param_->ends_, ends->MutableData(), axis_num * sizeof(int)); - - auto strides = in_tensors_.at(kStridesInex); - MS_ASSERT(strides != nullptr); - MS_ASSERT(axis_num == strides->ElementsNum()); - memcpy(param_->strides_, strides->MutableData(), axis_num * sizeof(int)); - - param_->num_axes_ = axis_num; - return RET_OK; -} - int StridedSliceCPUKernel::Run() { auto input = in_tensors_.at(0); MS_ASSERT(input); @@ -108,13 +70,6 @@ int StridedSliceCPUKernel::Run() { } auto output = out_tensors_.at(0); MS_ASSERT(output); - // inputs order: input, begin, end, stride - if (in_tensors().size() == kMultiInputsSize) { - auto ret = HandleMultiInputs(); - if (ret != RET_OK) { - return ret; - } - } auto ret = DoStridedSlice(input->MutableData(), output->MutableData(), param_); if (ret != RET_OK) { MS_LOG(ERROR) << "StridedSlice error error_code[" << ret << "]"; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h index 0a1751eda5b..0de0becec28 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/strided_slice.h @@ -36,9 +36,6 @@ class StridedSliceCPUKernel : public LiteKernel { int ReSize() override; int Run() override; - private: - int HandleMultiInputs(); - private: StridedSliceParameter *param_; }; diff --git a/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc b/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc index e0d7db5c87b..30acda2f043 100644 --- a/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/tflite_inputs_order_exchange_pass.cc @@ -60,7 +60,6 @@ bool TfliteInputsOrderExchangePass::Run(const FuncGraphPtr &graph) { } if (opt::GetCNodeType(node) == schema::PrimitiveType_Reduce || - opt::GetCNodeType(node) == schema::PrimitiveType_StridedSlice || opt::GetCNodeType(node) == schema::PrimitiveType_ArgMin || opt::GetCNodeType(node) == schema::PrimitiveType_ArgMax || opt::GetCNodeType(node) == schema::PrimitiveType_SpaceToBatch ||