forked from mindspore-Ecosystem/mindspore
stack support int32
This commit is contained in:
parent
50877b586d
commit
babff262e3
|
@ -17,7 +17,7 @@
|
||||||
#include "nnacl/fp32/stack.h"
|
#include "nnacl/fp32/stack.h"
|
||||||
#include "nnacl/arithmetic_common.h"
|
#include "nnacl/arithmetic_common.h"
|
||||||
|
|
||||||
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) {
|
size_t GetStackCopyNum(int axis, int *in_shape, size_t shape_size) {
|
||||||
size_t one_input_size = 1;
|
size_t one_input_size = 1;
|
||||||
for (size_t i = 0; i < shape_size; ++i) {
|
for (size_t i = 0; i < shape_size; ++i) {
|
||||||
one_input_size *= in_shape[i];
|
one_input_size *= in_shape[i];
|
||||||
|
@ -26,11 +26,37 @@ void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t
|
||||||
ComputeStrides(in_shape, in_strides, shape_size);
|
ComputeStrides(in_shape, in_strides, shape_size);
|
||||||
|
|
||||||
size_t copy_num = axis > 0 ? in_strides[axis - 1] : one_input_size;
|
size_t copy_num = axis > 0 ? in_strides[axis - 1] : one_input_size;
|
||||||
size_t copy_size = copy_num * sizeof(float);
|
return copy_num;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t GetStackPreAxisCount(const int *in_shape, int axis) {
|
||||||
size_t pre_axis_count = 1;
|
size_t pre_axis_count = 1;
|
||||||
for (size_t i = 0; i < axis; ++i) {
|
for (size_t i = 0; i < axis; ++i) {
|
||||||
pre_axis_count *= in_shape[i];
|
pre_axis_count *= in_shape[i];
|
||||||
}
|
}
|
||||||
|
return pre_axis_count;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) {
|
||||||
|
size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size);
|
||||||
|
size_t copy_size = copy_num * sizeof(float);
|
||||||
|
size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis);
|
||||||
|
size_t in_offset = 0;
|
||||||
|
size_t out_offset = 0;
|
||||||
|
for (size_t i = 0; i < pre_axis_count; ++i) {
|
||||||
|
for (size_t j = 0; j < input_num; ++j) {
|
||||||
|
memcpy(output + out_offset, inputs[j] + in_offset, copy_size);
|
||||||
|
out_offset += copy_num;
|
||||||
|
}
|
||||||
|
in_offset += copy_num;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
|
||||||
|
int32_t *output) {
|
||||||
|
size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size);
|
||||||
|
size_t copy_size = copy_num * sizeof(int32_t);
|
||||||
|
size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis);
|
||||||
size_t in_offset = 0;
|
size_t in_offset = 0;
|
||||||
size_t out_offset = 0;
|
size_t out_offset = 0;
|
||||||
for (size_t i = 0; i < pre_axis_count; ++i) {
|
for (size_t i = 0; i < pre_axis_count; ++i) {
|
||||||
|
|
|
@ -27,6 +27,8 @@ typedef struct StackParameter {
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output);
|
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output);
|
||||||
|
void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
|
||||||
|
int32_t *output);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -56,7 +56,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
||||||
return RET_PARAM_INVALID;
|
return RET_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
auto input = inputs.at(0);
|
auto input = inputs.at(0);
|
||||||
outputs[0]->set_data_type(input->data_type());
|
auto input0_data_type = input->data_type();
|
||||||
|
outputs[0]->set_data_type(input0_data_type);
|
||||||
outputs[0]->SetFormat(input->GetFormat());
|
outputs[0]->SetFormat(input->GetFormat());
|
||||||
if (!GetInferFlag()) {
|
if (!GetInferFlag()) {
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
@ -69,12 +70,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
||||||
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
|
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
|
||||||
return RET_PARAM_INVALID;
|
return RET_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
schema::Format input0_format = input->GetFormat();
|
|
||||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||||
if (inputs[i]->GetFormat() != input0_format) {
|
|
||||||
MS_LOG(ERROR) << "All inputs should have the same format!";
|
|
||||||
return RET_PARAM_INVALID;
|
|
||||||
}
|
|
||||||
auto input_shape_tmp = inputs[i]->shape();
|
auto input_shape_tmp = inputs[i]->shape();
|
||||||
if (input_shape_tmp.size() != input_shape.size()) {
|
if (input_shape_tmp.size() != input_shape.size()) {
|
||||||
MS_LOG(ERROR) << "All input shape size should be the same!";
|
MS_LOG(ERROR) << "All input shape size should be the same!";
|
||||||
|
@ -86,6 +83,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
||||||
return RET_PARAM_INVALID;
|
return RET_PARAM_INVALID;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (inputs[i]->data_type() != input0_data_type) {
|
||||||
|
MS_LOG(ERROR) << "All input shuld have the same data type!input[" << i << "] data type = "
|
||||||
|
<< inputs[i]->data_type();
|
||||||
|
return RET_PARAM_INVALID;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
output_shape.insert(output_shape.begin() + axis, inputs.size());
|
output_shape.insert(output_shape.begin() + axis, inputs.size());
|
||||||
outputs[0]->set_shape(output_shape);
|
outputs[0]->set_shape(output_shape);
|
||||||
|
|
|
@ -49,12 +49,21 @@ int StackCPUKernel::Run() {
|
||||||
}
|
}
|
||||||
size_t inputs_num = in_tensors_.size();
|
size_t inputs_num = in_tensors_.size();
|
||||||
auto input0_shape = in_tensors_[0]->shape();
|
auto input0_shape = in_tensors_[0]->shape();
|
||||||
auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());
|
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
|
||||||
float *inputs[inputs_num];
|
auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());
|
||||||
for (size_t i = 0; i < inputs_num; ++i) {
|
float *inputs[inputs_num];
|
||||||
inputs[i] = reinterpret_cast<float *>(in_tensors_[i]->Data());
|
for (size_t i = 0; i < inputs_num; ++i) {
|
||||||
|
inputs[i] = reinterpret_cast<float *>(in_tensors_[i]->Data());
|
||||||
|
}
|
||||||
|
DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
|
||||||
|
} else {
|
||||||
|
auto *output_data = reinterpret_cast<int32_t *>(out_tensors_[0]->Data());
|
||||||
|
int32_t *inputs[inputs_num];
|
||||||
|
for (size_t i = 0; i < inputs_num; ++i) {
|
||||||
|
inputs[i] = reinterpret_cast<int32_t *>(in_tensors_[i]->Data());
|
||||||
|
}
|
||||||
|
DoStackInt32(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
|
||||||
}
|
}
|
||||||
DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
|
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,4 +94,5 @@ kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector<lite::tensor::Te
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Stack, CpuStackFp32KernelCreator)
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Stack, CpuStackFp32KernelCreator)
|
||||||
|
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Stack, CpuStackFp32KernelCreator)
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
Loading…
Reference in New Issue