stack support int32

This commit is contained in:
chenjianping 2020-08-22 18:17:56 +08:00
parent 50877b586d
commit babff262e3
4 changed files with 53 additions and 13 deletions

View File

@ -17,7 +17,7 @@
#include "nnacl/fp32/stack.h" #include "nnacl/fp32/stack.h"
#include "nnacl/arithmetic_common.h" #include "nnacl/arithmetic_common.h"
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) { size_t GetStackCopyNum(int axis, int *in_shape, size_t shape_size) {
size_t one_input_size = 1; size_t one_input_size = 1;
for (size_t i = 0; i < shape_size; ++i) { for (size_t i = 0; i < shape_size; ++i) {
one_input_size *= in_shape[i]; one_input_size *= in_shape[i];
@ -26,11 +26,37 @@ void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t
ComputeStrides(in_shape, in_strides, shape_size); ComputeStrides(in_shape, in_strides, shape_size);
size_t copy_num = axis > 0 ? in_strides[axis - 1] : one_input_size; size_t copy_num = axis > 0 ? in_strides[axis - 1] : one_input_size;
size_t copy_size = copy_num * sizeof(float); return copy_num;
}
size_t GetStackPreAxisCount(const int *in_shape, int axis) {
size_t pre_axis_count = 1; size_t pre_axis_count = 1;
for (size_t i = 0; i < axis; ++i) { for (size_t i = 0; i < axis; ++i) {
pre_axis_count *= in_shape[i]; pre_axis_count *= in_shape[i];
} }
return pre_axis_count;
}
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) {
size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size);
size_t copy_size = copy_num * sizeof(float);
size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis);
size_t in_offset = 0;
size_t out_offset = 0;
for (size_t i = 0; i < pre_axis_count; ++i) {
for (size_t j = 0; j < input_num; ++j) {
memcpy(output + out_offset, inputs[j] + in_offset, copy_size);
out_offset += copy_num;
}
in_offset += copy_num;
}
}
void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
int32_t *output) {
size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size);
size_t copy_size = copy_num * sizeof(int32_t);
size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis);
size_t in_offset = 0; size_t in_offset = 0;
size_t out_offset = 0; size_t out_offset = 0;
for (size_t i = 0; i < pre_axis_count; ++i) { for (size_t i = 0; i < pre_axis_count; ++i) {

View File

@ -27,6 +27,8 @@ typedef struct StackParameter {
extern "C" { extern "C" {
#endif #endif
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output); void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output);
void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
int32_t *output);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -56,7 +56,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
auto input = inputs.at(0); auto input = inputs.at(0);
outputs[0]->set_data_type(input->data_type()); auto input0_data_type = input->data_type();
outputs[0]->set_data_type(input0_data_type);
outputs[0]->SetFormat(input->GetFormat()); outputs[0]->SetFormat(input->GetFormat());
if (!GetInferFlag()) { if (!GetInferFlag()) {
return RET_OK; return RET_OK;
@ -69,12 +70,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
MS_LOG(ERROR) << "Invalid axis " << GetAxis(); MS_LOG(ERROR) << "Invalid axis " << GetAxis();
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
schema::Format input0_format = input->GetFormat();
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 1; i < inputs.size(); ++i) {
if (inputs[i]->GetFormat() != input0_format) {
MS_LOG(ERROR) << "All inputs should have the same format!";
return RET_PARAM_INVALID;
}
auto input_shape_tmp = inputs[i]->shape(); auto input_shape_tmp = inputs[i]->shape();
if (input_shape_tmp.size() != input_shape.size()) { if (input_shape_tmp.size() != input_shape.size()) {
MS_LOG(ERROR) << "All input shape size should be the same!"; MS_LOG(ERROR) << "All input shape size should be the same!";
@ -86,6 +83,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
} }
if (inputs[i]->data_type() != input0_data_type) {
MS_LOG(ERROR) << "All input shuld have the same data type!input[" << i << "] data type = "
<< inputs[i]->data_type();
return RET_PARAM_INVALID;
}
} }
output_shape.insert(output_shape.begin() + axis, inputs.size()); output_shape.insert(output_shape.begin() + axis, inputs.size());
outputs[0]->set_shape(output_shape); outputs[0]->set_shape(output_shape);

View File

@ -49,12 +49,21 @@ int StackCPUKernel::Run() {
} }
size_t inputs_num = in_tensors_.size(); size_t inputs_num = in_tensors_.size();
auto input0_shape = in_tensors_[0]->shape(); auto input0_shape = in_tensors_[0]->shape();
auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data()); if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
float *inputs[inputs_num]; auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());
for (size_t i = 0; i < inputs_num; ++i) { float *inputs[inputs_num];
inputs[i] = reinterpret_cast<float *>(in_tensors_[i]->Data()); for (size_t i = 0; i < inputs_num; ++i) {
inputs[i] = reinterpret_cast<float *>(in_tensors_[i]->Data());
}
DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
} else {
auto *output_data = reinterpret_cast<int32_t *>(out_tensors_[0]->Data());
int32_t *inputs[inputs_num];
for (size_t i = 0; i < inputs_num; ++i) {
inputs[i] = reinterpret_cast<int32_t *>(in_tensors_[i]->Data());
}
DoStackInt32(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
} }
DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
return RET_OK; return RET_OK;
} }
@ -85,4 +94,5 @@ kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector<lite::tensor::Te
} }
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Stack, CpuStackFp32KernelCreator) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Stack, CpuStackFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Stack, CpuStackFp32KernelCreator)
} // namespace mindspore::kernel } // namespace mindspore::kernel