forked from mindspore-Ecosystem/mindspore
stack support int32
This commit is contained in:
parent
50877b586d
commit
babff262e3
|
@ -17,7 +17,7 @@
|
|||
#include "nnacl/fp32/stack.h"
|
||||
#include "nnacl/arithmetic_common.h"
|
||||
|
||||
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) {
|
||||
size_t GetStackCopyNum(int axis, int *in_shape, size_t shape_size) {
|
||||
size_t one_input_size = 1;
|
||||
for (size_t i = 0; i < shape_size; ++i) {
|
||||
one_input_size *= in_shape[i];
|
||||
|
@ -26,11 +26,37 @@ void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t
|
|||
ComputeStrides(in_shape, in_strides, shape_size);
|
||||
|
||||
size_t copy_num = axis > 0 ? in_strides[axis - 1] : one_input_size;
|
||||
size_t copy_size = copy_num * sizeof(float);
|
||||
return copy_num;
|
||||
}
|
||||
|
||||
size_t GetStackPreAxisCount(const int *in_shape, int axis) {
|
||||
size_t pre_axis_count = 1;
|
||||
for (size_t i = 0; i < axis; ++i) {
|
||||
pre_axis_count *= in_shape[i];
|
||||
}
|
||||
return pre_axis_count;
|
||||
}
|
||||
|
||||
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output) {
|
||||
size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size);
|
||||
size_t copy_size = copy_num * sizeof(float);
|
||||
size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis);
|
||||
size_t in_offset = 0;
|
||||
size_t out_offset = 0;
|
||||
for (size_t i = 0; i < pre_axis_count; ++i) {
|
||||
for (size_t j = 0; j < input_num; ++j) {
|
||||
memcpy(output + out_offset, inputs[j] + in_offset, copy_size);
|
||||
out_offset += copy_num;
|
||||
}
|
||||
in_offset += copy_num;
|
||||
}
|
||||
}
|
||||
|
||||
void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
|
||||
int32_t *output) {
|
||||
size_t copy_num = GetStackCopyNum(axis, in_shape, shape_size);
|
||||
size_t copy_size = copy_num * sizeof(int32_t);
|
||||
size_t pre_axis_count = GetStackPreAxisCount(in_shape, axis);
|
||||
size_t in_offset = 0;
|
||||
size_t out_offset = 0;
|
||||
for (size_t i = 0; i < pre_axis_count; ++i) {
|
||||
|
|
|
@ -27,6 +27,8 @@ typedef struct StackParameter {
|
|||
extern "C" {
|
||||
#endif
|
||||
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output);
|
||||
void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
|
||||
int32_t *output);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -56,7 +56,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
|||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto input = inputs.at(0);
|
||||
outputs[0]->set_data_type(input->data_type());
|
||||
auto input0_data_type = input->data_type();
|
||||
outputs[0]->set_data_type(input0_data_type);
|
||||
outputs[0]->SetFormat(input->GetFormat());
|
||||
if (!GetInferFlag()) {
|
||||
return RET_OK;
|
||||
|
@ -69,12 +70,8 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
|||
MS_LOG(ERROR) << "Invalid axis " << GetAxis();
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
schema::Format input0_format = input->GetFormat();
|
||||
|
||||
for (size_t i = 1; i < inputs.size(); ++i) {
|
||||
if (inputs[i]->GetFormat() != input0_format) {
|
||||
MS_LOG(ERROR) << "All inputs should have the same format!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto input_shape_tmp = inputs[i]->shape();
|
||||
if (input_shape_tmp.size() != input_shape.size()) {
|
||||
MS_LOG(ERROR) << "All input shape size should be the same!";
|
||||
|
@ -86,6 +83,11 @@ int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
|||
return RET_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
if (inputs[i]->data_type() != input0_data_type) {
|
||||
MS_LOG(ERROR) << "All input shuld have the same data type!input[" << i << "] data type = "
|
||||
<< inputs[i]->data_type();
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
output_shape.insert(output_shape.begin() + axis, inputs.size());
|
||||
outputs[0]->set_shape(output_shape);
|
||||
|
|
|
@ -49,12 +49,21 @@ int StackCPUKernel::Run() {
|
|||
}
|
||||
size_t inputs_num = in_tensors_.size();
|
||||
auto input0_shape = in_tensors_[0]->shape();
|
||||
auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());
|
||||
float *inputs[inputs_num];
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
inputs[i] = reinterpret_cast<float *>(in_tensors_[i]->Data());
|
||||
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
|
||||
auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());
|
||||
float *inputs[inputs_num];
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
inputs[i] = reinterpret_cast<float *>(in_tensors_[i]->Data());
|
||||
}
|
||||
DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
|
||||
} else {
|
||||
auto *output_data = reinterpret_cast<int32_t *>(out_tensors_[0]->Data());
|
||||
int32_t *inputs[inputs_num];
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
inputs[i] = reinterpret_cast<int32_t *>(in_tensors_[i]->Data());
|
||||
}
|
||||
DoStackInt32(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
|
||||
}
|
||||
DoStack(inputs, inputs_num, input0_shape.data(), input0_shape.size(), axis_, output_data);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
@ -85,4 +94,5 @@ kernel::LiteKernel *CpuStackFp32KernelCreator(const std::vector<lite::tensor::Te
|
|||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Stack, CpuStackFp32KernelCreator)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Stack, CpuStackFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
Loading…
Reference in New Issue