!5668 [MS][LITE][Develop]stack support one input

Merge pull request !5668 from chenjianping/lite_dev
This commit is contained in:
mindspore-ci-bot 2020-09-02 21:25:25 +08:00 committed by Gitee
commit 6b107f4412
4 changed files with 12 additions and 1 deletions

View File

@ -67,3 +67,7 @@ void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape,
in_offset += copy_num;
}
}
void DoStackOneInput(const int8_t *input, int8_t *output, size_t data_size) {
memcpy(output, input, data_size);
}

View File

@ -29,6 +29,7 @@ extern "C" {
void DoStack(const float *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis, float *output);
void DoStackInt32(const int32_t *const *inputs, size_t input_num, int *in_shape, size_t shape_size, int axis,
int32_t *output);
void DoStackOneInput(const int8_t *input, int8_t *output, size_t data_size);
#ifdef __cplusplus
}
#endif

View File

@ -58,7 +58,7 @@ int Stack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
namespace {
constexpr int kStackOutputNum = 1;
constexpr int kStackMinInputNum = 2;
constexpr int kStackMinInputNum = 1;
} // namespace
int Stack::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) {
MS_ASSERT(this->primitive_ != nullptr);

View File

@ -48,6 +48,12 @@ int StackCPUKernel::Run() {
return ret;
}
size_t inputs_num = in_tensors_.size();
auto input0 = in_tensors_[0];
if (inputs_num == 1) {
auto *output_data = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
DoStackOneInput(reinterpret_cast<const int8_t *>(input0->Data()), output_data, input0->Size());
return RET_OK;
}
auto input0_shape = in_tensors_[0]->shape();
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
auto *output_data = reinterpret_cast<float *>(out_tensors_[0]->Data());