From e6766924cb7379f00a0a59ed42147522bebec8c2 Mon Sep 17 00:00:00 2001 From: chenjianping Date: Fri, 11 Sep 2020 16:30:07 +0800 Subject: [PATCH] support batch_to_space_nd,fix stack minus axis bug --- mindspore/lite/src/ops/primitive_c.cc | 2 + mindspore/lite/src/ops/stack.cc | 2 +- mindspore/lite/src/populate_parameter.cc | 1 + .../kernel/arm/base/batch_to_space_base.cc | 16 +++++-- .../lite/src/runtime/kernel/arm/fp32/stack.cc | 2 +- .../kernel/arm/fp32/stack_fp32_test.cc | 48 +++++++++++++++++++ 6 files changed, 64 insertions(+), 7 deletions(-) create mode 100644 mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 912db6aef2f..55e19d0b3f7 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -598,6 +598,7 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { case schema::PrimitiveType_Unsqueeze: return new Unsqueeze(primitive); case schema::PrimitiveType_BatchToSpace: + case schema::PrimitiveType_BatchToSpaceND: return new BatchToSpace(primitive); case schema::PrimitiveType_SpaceToBatch: return new SpaceToBatch(primitive); @@ -857,6 +858,7 @@ PrimitiveC *PrimitiveC::Create(const schema::Primitive *primitive) { case schema::PrimitiveType_Unsqueeze: return NewPrimitiveC(primitive); case schema::PrimitiveType_BatchToSpace: + case schema::PrimitiveType_BatchToSpaceND: return NewPrimitiveC(primitive); case schema::PrimitiveType_SpaceToBatch: return NewPrimitiveC(primitive); diff --git a/mindspore/lite/src/ops/stack.cc b/mindspore/lite/src/ops/stack.cc index 208c79a1b0d..c7752ada1ad 100644 --- a/mindspore/lite/src/ops/stack.cc +++ b/mindspore/lite/src/ops/stack.cc @@ -80,7 +80,7 @@ int Stack::InferShape(std::vector inputs, std::vector output auto input_shape = input->shape(); std::vector output_shape = input_shape; - auto axis = GetAxis() < 0 ? GetAxis() + input_shape.size() : GetAxis(); + auto axis = GetAxis() < 0 ? GetAxis() + input_shape.size() + 1 : GetAxis(); if (axis < 0 || axis > input_shape.size()) { MS_LOG(ERROR) << "Invalid axis " << GetAxis(); return RET_PARAM_INVALID; diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 5ad708b3165..e39c4c94bec 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -1669,6 +1669,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() { populate_parameter_funcs_[schema::PrimitiveType_Pad] = PopulatePadParameter; populate_parameter_funcs_[schema::PrimitiveType_Resize] = PopulateResizeParameter; populate_parameter_funcs_[schema::PrimitiveType_BatchToSpace] = PopulateBatchToSpaceParameter; + populate_parameter_funcs_[schema::PrimitiveType_BatchToSpaceND] = PopulateBatchToSpaceParameter; populate_parameter_funcs_[schema::PrimitiveType_SpaceToDepth] = PopulateSpaceToDepthParameter; populate_parameter_funcs_[schema::PrimitiveType_SpaceToBatch] = PopulateSpaceToBatchParameter; populate_parameter_funcs_[schema::PrimitiveType_SpaceToBatchND] = PopulateSpaceToBatchNDParameter; diff --git a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc index 96fd47cb5ea..e7fdfee2f8f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/batch_to_space_base.cc @@ -27,9 +27,14 @@ using mindspore::lite::RET_ERROR; using mindspore::lite::RET_FORMAT_ERR; using mindspore::lite::RET_OK; using mindspore::schema::PrimitiveType_BatchToSpace; +using mindspore::schema::PrimitiveType_BatchToSpaceND; namespace mindspore::kernel { int BatchToSpaceBaseCPUKernel::Init() { + if (in_tensors_[0]->GetFormat() != schema::Format::Format_NHWC) { + MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; + return RET_FORMAT_ERR; + } BatchToSpaceParameter *param = reinterpret_cast(this->op_parameter_); for (int i = 0; i < BATCH_TO_SPACE_CROPS_SIZE; ++i) { if (param->crops_[i] != 0) { @@ -40,9 +45,10 @@ int BatchToSpaceBaseCPUKernel::Init() { } int BatchToSpaceBaseCPUKernel::ReSize() { - if (in_tensors_[0]->GetFormat() != schema::Format::Format_NHWC) { - MS_LOG(ERROR) << "batch_to_space only support NHWC now!"; - return RET_FORMAT_ERR; + auto shape = in_tensors_[0]->shape(); + if (shape.size() != 4) { + MS_LOG(ERROR) << "Unsupport shape size: " << shape.size(); + return RET_ERROR; } return RET_OK; } @@ -52,7 +58,6 @@ kernel::LiteKernel *CpuBatchToSpaceInt8KernelCreator(const std::vector(op_parameter_); auto input0_shape = in_tensors_[0]->shape(); - axis_ = param->axis_ < 0 ? param->axis_ + input0_shape.size() : param->axis_; + axis_ = param->axis_ < 0 ? param->axis_ + input0_shape.size() + 1 : param->axis_; return RET_OK; } diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc new file mode 100644 index 00000000000..8bf5b6b836c --- /dev/null +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/stack_fp32_test.cc @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "common/common_test.h" +#include "mindspore/lite/nnacl/fp32/stack.h" + +namespace mindspore { +class StackTestFp32 : public mindspore::CommonTest { + public: + StackTestFp32() = default; +}; + +TEST_F(StackTestFp32, StackTest1) { + float input0[6] = {1, 2, 3, 10, 20, 30}; + float input1[6] = {4, 5, 6, 40, 50, 60}; + float input2[6] = {7, 8, 9, 70, 80, 90}; + float *input[3]; + input[0] = input0; + input[1] = input1; + input[2] = input2; + std::vector shape = {2, 3}; + int axis = 2; + constexpr int kOutSize = 18; + float expect_out[kOutSize] = {1, 4, 7, 2, 5, 8, 3, 6, 9, + 10, 40, 70, 20, 50, 80, 30, 60, 90}; + float output[kOutSize]; + DoStack(input, 3, shape.data(), shape.size(), axis, output); + for (int i = 0; i < kOutSize; ++i) { + std::cout << output[i] << " "; + } + std::cout << "\n"; + CompareOutputData(output, expect_out, kOutSize, 0.000001); +} + + +} // namespace mindspore