From babff262e3f38143541bff1868717fa7ad956267 Mon Sep 17 00:00:00 2001 From: chenjianping Date: Sat, 22 Aug 2020 18:17:56 +0800 Subject: [PATCH] stack support int32 --- mindspore/lite/nnacl/fp32/stack.c | 30 +++++++++++++++++-- mindspore/lite/nnacl/fp32/stack.h | 2 ++ mindspore/lite/src/ops/stack.cc | 14 +++++---- .../lite/src/runtime/kernel/arm/fp32/stack.cc | 20 +++++++++---- 4 files changed, 53 insertions(+), 13 deletions(-) diff --git a/mindspore/lite/nnacl/fp32/stack.c b/mindspore/lite/nnacl/fp32/stack.c index 70a403791db..28182b0b848 100644 --- a/mindspore/lite/nnacl/fp32/stack.c +++ b/mindspore/lite/nnacl/fp32/stack.c @@ -17,7 +17,7 @@ #include "nnacl/fp32/stack.h" #include "nnacl/arithmetic_common.h" -void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) { +size_t GetStackCopyNum(int axis, int *in_shape, size_t shape_size) { size_t one_input_size = 1; for (size_t i = 0; i < shape_size; ++i) { one_input_size *= in_shape[i]; @@ -26,11 +26,37 @@ void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t ComputeStrides(in_shape, in_strides, shape_size); size_t copy_num = axis > 0 ? 
in_strides[axis - 1] : one_input_size; - size_t copy_size = copy_num * sizeof(float); + return copy_num; +} + +size_t GetStackPreAxisCount(const int *in_shape, int axis) { size_t pre_axis_count = 1; for (size_t i = 0; i < axis; ++i) { pre_axis_count *= in_shape[i]; } + return pre_axis_count; +} + +void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) { + size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size); + size_t copy_size = copy_num * sizeof(float); + size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis); + size_t in_offset = 0; + size_t out_offset = 0; + for (size_t i = 0; i < pre_axis_count; ++i) { + for (size_t j = 0; j < input_num; ++j) { + memcpy(output + out_offset, inputs[j] + in_offset, copy_size); + out_offset += copy_num; + } + in_offset += copy_num; + } +} + +void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, + int32_t *output) { + size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size); + size_t copy_size = copy_num * sizeof(int32_t); + size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis); size_t in_offset = 0; size_t out_offset = 0; for (size_t i = 0; i < pre_axis_count; ++i) { diff --git a/mindspore/lite/nnacl/fp32/stack.h b/mindspore/lite/nnacl/fp32/stack.h index 2bc8ed8af0f..652d4263d8e 100644 --- a/mindspore/lite/nnacl/fp32/stack.h +++ b/mindspore/lite/nnacl/fp32/stack.h @@ -27,6 +27,8 @@ typedef struct StackParameter { extern "C" { #endif void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output); +void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, + int32_t *output); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/ops/stack.cc b/mindspore/lite/src/ops/stack.cc index 17f7c05ac31..4a1d5f02736 100644 --- a/mindspore/lite/src/ops/stack.cc +++ 
b/mindspore/lite/src/ops/stack.cc @@ -56,7 +56,8 @@ int Stack::InferShape(std::vector inputs, std::vectorset_data_type(input->data_type()); + auto input0_data_type = input->data_type(); + outputs[0]->set_data_type(input0_data_type); outputs[0]->SetFormat(input->GetFormat()); if (!GetInferFlag()) { return RET_OK; @@ -69,12 +70,8 @@ int Stack::InferShape(std::vector inputs, std::vectorGetFormat(); + for (size_t i = 1; i < inputs.size(); ++i) { - if (inputs[i]->GetFormat() != input0_format) { - MS_LOG(ERROR) << "All inputs should have the same format!"; - return RET_PARAM_INVALID; - } auto input_shape_tmp = inputs[i]->shape(); if (input_shape_tmp.size() != input_shape.size()) { MS_LOG(ERROR) << "All input shape size should be the same!"; return RET_PARAM_INVALID; @@ -86,6 +83,11 @@ int Stack::InferShape(std::vector inputs, std::vectordata_type() != input0_data_type) { + MS_LOG(ERROR) << "All inputs should have the same data type! input[" << i << "] data type = " + << inputs[i]->data_type(); + return RET_PARAM_INVALID; + } } output_shape.insert(output_shape.begin() + axis, inputs.size()); outputs[0]->set_shape(output_shape); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc index 81ae87979b7..a6152c2323a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/stack.cc @@ -49,12 +49,21 @@ int StackCPUKernel::Run() { } size_t inputs_num = in_tensors_.size(); auto input0_shape = in_tensors_[0]->shape(); - auto *output_data = reinterpret_cast(out_tensors_[0]->Data()); - float *inputs[inputs_num]; - for (size_t i = 0; i < inputs_num; ++i) { - inputs[i] = reinterpret_cast(in_tensors_[i]->Data()); + if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) { + auto *output_data = reinterpret_cast(out_tensors_[0]->Data()); + float *inputs[inputs_num]; + for (size_t i = 0; i < inputs_num; ++i) { + inputs[i] =
reinterpret_cast(in_tensors_[i]->Data()); + } + DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data); + } else { + auto *output_data = reinterpret_cast(out_tensors_[0]->Data()); + int32_t *inputs[inputs_num]; + for (size_t i = 0; i < inputs_num; ++i) { + inputs[i] = reinterpret_cast(in_tensors_[i]->Data()); + } + DoStackInt32(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data); } - DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data); return RET_OK; } @@ -85,4 +94,5 @@ kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector