diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index 5767491aa85..4ee99d86cae 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -16,6 +16,7 @@ #include "src/runtime/kernel/arm/fp16/arithmetic_fp16.h" #include "src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h" +#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "src/runtime/runtime_api.h" @@ -31,7 +32,7 @@ using mindspore::schema::PrimitiveType_Mul; using mindspore::schema::PrimitiveType_Sub; namespace mindspore::kernel { -void ArithmeticFP16CPUKernel::FreeTileData() { +void ArithmeticFP16CPUKernel::FreeTmpBuffer() { if (tile_data0_ != nullptr) { free(tile_data0_); tile_data0_ = nullptr; @@ -40,9 +41,21 @@ void ArithmeticFP16CPUKernel::FreeTileData() { free(tile_data1_); tile_data1_ = nullptr; } + if (input0_fp16_ != nullptr) { + context_->allocator->Free(input0_fp16_); + input0_fp16_ = nullptr; + } + if (input1_fp16_ != nullptr) { + context_->allocator->Free(input1_fp16_); + input1_fp16_ = nullptr; + } + if (output_fp16_ != nullptr) { + context_->allocator->Free(output_fp16_); + output_fp16_ = nullptr; + } } -ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() { FreeTileData(); } +ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() { FreeTmpBuffer(); } int ArithmeticFP16CPUKernel::Init() { switch (op_parameter_->type_) { @@ -97,10 +110,38 @@ int ArithmeticFP16CPUKernel::Init() { } int ArithmeticFP16CPUKernel::ReSize() { - FreeTileData(); + FreeTmpBuffer(); arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); + if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) { + input0_fp16_ = reinterpret_cast(context_->allocator->Malloc( + arithmeticParameter_->in_elements_num0_ * sizeof(float16_t))); + if (input0_fp16_ == nullptr) { + MS_LOG(ERROR) << "malloc data fail!"; + return RET_ERROR; + } + Float32ToFloat16(reinterpret_cast(in_tensors_[0]->Data()), input0_fp16_, + arithmeticParameter_->in_elements_num0_); + } + if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) { + input1_fp16_ = reinterpret_cast(context_->allocator->Malloc( + arithmeticParameter_->in_elements_num1_ * sizeof(float16_t))); + if (input0_fp16_ == nullptr) { + MS_LOG(ERROR) << "malloc data fail!"; + return RET_ERROR; + } + Float32ToFloat16(reinterpret_cast(in_tensors_[1]->Data()), input1_fp16_, + arithmeticParameter_->in_elements_num1_); + } + if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) { + output_fp16_ = reinterpret_cast(context_->allocator->Malloc( + arithmeticParameter_->out_elements_num_ * sizeof(float16_t))); + if (output_fp16_ == nullptr) { + MS_LOG(ERROR) << "malloc data fail!"; + return RET_ERROR; + } + } if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { if (arithmeticParameter_->activation_type_ == schema::ActivationType_NO_ACTIVATION) { @@ -137,13 +178,17 @@ int ArithmeticFP16CPUKernel::ReSize() { } int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) { - auto input0_data = reinterpret_cast(in_tensors_[0]->Data()); - auto input1_data1 = reinterpret_cast(in_tensors_[1]->Data()); - auto output_data = reinterpret_cast(out_tensors_[0]->Data()); + auto input0 = reinterpret_cast(in_tensors_[0]->Data()); + auto input1 = reinterpret_cast(in_tensors_[1]->Data()); + auto output = reinterpret_cast(out_tensors_[0]->Data()); auto element_num = out_tensors_[0]->ElementsNum(); + float16_t *input0_data = input0_fp16_ == nullptr ? input0 : input0_fp16_; + float16_t *input1_data1 = input1_fp16_ == nullptr ? input1 : input1_fp16_; + auto output_data = output_fp16_ == nullptr ? output : output_fp16_; int stride = UP_DIV(element_num, context_->thread_num_); int count = MSMIN(stride, element_num - stride * task_id); + auto thread_stride = stride * task_id; if (arithmetic_run_ == nullptr) { MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; @@ -152,26 +197,30 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) { int error_code = RET_OK; if (arithmeticParameter_->broadcasting_) { - error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id, - output_data + stride * task_id, count); + error_code = arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride, + output_data + thread_stride, count); } else if (arithmetic_opt_run_ != nullptr) { if (arithmeticParameter_->in_elements_num0_ == 1) { - error_code = arithmetic_opt_run_(input0_data, input1_data1 + stride * task_id, output_data + stride * task_id, + error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count, arithmeticParameter_); } else if (arithmeticParameter_->in_elements_num1_ == 1) { - error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1, output_data + stride * task_id, + error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1, output_data + thread_stride, count, arithmeticParameter_); } else { - error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1 + stride * task_id, - output_data + stride * task_id, count, arithmeticParameter_); + error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1 + thread_stride, + output_data + thread_stride, count, arithmeticParameter_); } } else { - error_code = arithmetic_run_(input0_data + stride * task_id, input1_data1 + stride * task_id, - output_data + stride * task_id, count); + error_code = arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride, + output_data + thread_stride, count); } if (error_code != RET_OK) { return RET_ERROR; } + if (output_fp16_ != nullptr) { + auto output_fp32 = reinterpret_cast(out_tensors_[0]->Data()); + Float16ToFloat32(output_data + thread_stride, output_fp32 + thread_stride, count); + } return RET_OK; } @@ -195,7 +244,9 @@ int ArithmeticFP16CPUKernel::Run() { if (arithmeticParameter_->broadcasting_) { auto input_data0 = reinterpret_cast(in_tensors_[0]->Data()); auto input_data1 = reinterpret_cast(in_tensors_[1]->Data()); - TileDimensionsFp16(input_data0, input_data1, tile_data0_, tile_data1_, arithmeticParameter_); + float16_t *input0 = input0_fp16_ == nullptr ? input_data0 : input0_fp16_; + float16_t *input1 = input1_fp16_ == nullptr ? input_data1 : input1_fp16_; + TileDimensionsFp16(input0, input1, tile_data0_, tile_data1_, arithmeticParameter_); } ret = LiteBackendParallelLaunch(ArithmeticsRun, this, context_->thread_num_); if (ret != RET_OK) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h index af589f7ae2e..8a58bb08802 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h @@ -43,9 +43,12 @@ class ArithmeticFP16CPUKernel : public LiteKernel { int DoArithmetic(int task_id); private: - void FreeTileData(); + void FreeTmpBuffer(); float16_t *tile_data0_ = nullptr; float16_t *tile_data1_ = nullptr; + float16_t *input0_fp16_ = nullptr; + float16_t *input1_fp16_ = nullptr; + float16_t *output_fp16_ = nullptr; ArithmeticParameter *arithmeticParameter_ = nullptr; ArithmeticRun arithmetic_run_ = nullptr; ArithmeticOptRun arithmetic_opt_run_ = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc index 46587db0f14..9a7e1a9a0d1 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/gather.cc @@ -49,6 +49,9 @@ int GatherCPUKernel::DoGather(int task_id) { auto indices_ptr = reinterpret_cast(indices_tensor->Data()); auto output_ptr = reinterpret_cast(out_tensor->Data()); + auto input_int32 = reinterpret_cast(input_tensor->Data()); + auto output_int32 = reinterpret_cast(out_tensor->Data()); + auto in_shape = input_tensor->shape(); int in_rank = in_shape.size(); int indices_element_size = indices_tensor->ElementsNum(); @@ -73,11 +76,19 @@ int GatherCPUKernel::DoGather(int task_id) { int stride = UP_DIV(outer_size, thread_count_); int count = MSMIN(stride, outer_size - stride * task_id); + auto thread_stride = stride * task_id; - input_ptr += stride * task_id * limit; - output_ptr += stride * task_id * indices_element_size; + int error_code; + if (input_tensor->data_type() == kNumberTypeInt32) { + input_int32 += thread_stride * limit; + output_int32 += thread_stride * indices_element_size; + error_code = GatherInt32(input_int32, count, inner_size, limit, indices_ptr, indices_element_size, output_int32); + } else { + input_ptr += thread_stride * limit; + output_ptr += thread_stride * indices_element_size; + error_code = Gather(input_ptr, count, inner_size, limit, indices_ptr, indices_element_size, output_ptr); + } - auto error_code = Gather(input_ptr, count, inner_size, limit, indices_ptr, indices_element_size, output_ptr); if (error_code != RET_OK) { return RET_ERROR; } @@ -110,19 +121,21 @@ int GatherCPUKernel::Run() { kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, + OpParameter *parameter, const lite::Context *ctx, const kernel::KernelKey &desc, const lite::Primitive *primitive) { - MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_Gather); - - auto *kernel = new (std::nothrow) GatherCPUKernel(opParameter, inputs, outputs, ctx, primitive); + if (parameter == nullptr) { + MS_LOG(ERROR) << "input parameter is nullptr!"; + return nullptr; + } + auto *kernel = new (std::nothrow) GatherCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { return nullptr; } auto ret = kernel->Init(); if (ret != RET_OK) { - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); delete kernel; return nullptr; } @@ -130,4 +143,5 @@ kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector(in_tensors_.front()->Data()); - auto output_ptr = reinterpret_cast(out_tensors_.front()->Data()); + size_t data_size = in_tensors_.front()->Size(); - ret = DoSqueeze(input_ptr, output_ptr, data_size); + if (in_tensors_.front()->data_type() == kNumberTypeInt32) { + auto input_ptr = reinterpret_cast(in_tensors_.front()->Data()); + auto output_ptr = reinterpret_cast(out_tensors_.front()->Data()); + ret = DoSqueezeInt32(input_ptr, output_ptr, data_size); + } else { + auto input_ptr = reinterpret_cast(in_tensors_.front()->Data()); + auto output_ptr = reinterpret_cast(out_tensors_.front()->Data()); + ret = DoSqueeze(input_ptr, output_ptr, data_size); + } + if (ret != RET_OK) { - MS_LOG(ERROR) << "Do squeeze failed."; + MS_LOG(ERROR) << "Do squeeze fail!ret: " << ret; return RET_ERROR; } return RET_OK; @@ -55,14 +63,14 @@ int SqueezeCPUKernel::Run() { kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, - OpParameter *opParameter, const lite::Context *ctx, + OpParameter *parameter, const lite::Context *ctx, const kernel::KernelKey &desc, const lite::Primitive *primitive) { MS_ASSERT(desc.type == schema::PrimitiveType_Squeeze); - if (opParameter == nullptr) { + if (parameter == nullptr) { MS_LOG(ERROR) << "desc type is not Squeeze"; return nullptr; } - auto *kernel = new (std::nothrow) SqueezeCPUKernel(opParameter, inputs, outputs, ctx, primitive); + auto *kernel = new (std::nothrow) SqueezeCPUKernel(parameter, inputs, outputs, ctx, primitive); if (kernel == nullptr) { MS_LOG(ERROR) << "New kernel fails."; return nullptr; @@ -70,8 +78,8 @@ kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vectorInit(); if (ret != RET_OK) { - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " - << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); + MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ << ", type: " + << schema::EnumNamePrimitiveType(static_cast(parameter->type_)); delete kernel; return nullptr; } @@ -80,4 +88,5 @@ kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector 0 ? res : 0; res = input0[1] * input1[1]; @@ -132,9 +132,9 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, output[2] = res > 0 ? res : 0; res = input0[3] * input1[3]; output[3] = res > 0 ? res : 0; - input0 += C4NUM; - input1 += C4NUM; - output += C4NUM; + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; } for (int index = 0; index < block_mod; ++index) { float16_t res = input0[index] * input1[index]; @@ -145,17 +145,17 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, } int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { - int block_mod = element_size % C4NUM; - int block_c4 = element_size - block_mod; + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; - for (int index = 0; index < block_c4; index += C4NUM) { + for (int index = 0; index < block_c8; index += C8NUM) { output[0] = MSMIN(MSMAX(input0[0] * input1[0], 0), 6); output[1] = MSMIN(MSMAX(input0[1] * input1[1], 0), 6); output[2] = MSMIN(MSMAX(input0[2] * input1[2], 0), 6); output[3] = MSMIN(MSMAX(input0[3] * input1[3], 0), 6); - input0 += C4NUM; - input1 += C4NUM; - output += C4NUM; + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; } for (int index = 0; index < block_mod; ++index) { output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6); @@ -165,17 +165,17 @@ int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, } int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { - int block_mod = element_size % C4NUM; - int block_c4 = element_size - block_mod; + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; - for (int index = 0; index < block_c4; index += C4NUM) { + for (int index = 0; index < block_c8; index += C8NUM) { output[0] = input0[0] + input1[0]; output[1] = input0[1] + input1[1]; output[2] = input0[2] + input1[2]; output[3] = input0[3] + input1[3]; - input0 += C4NUM; - input1 += C4NUM; - output += C4NUM; + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; } for (int index = 0; index < block_mod; ++index) { output[index] = input0[index] + input1[index]; @@ -184,10 +184,10 @@ int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int } int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { - int block_mod = element_size % C4NUM; - int block_c4 = element_size - block_mod; + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; - for (int index = 0; index < block_c4; index += C4NUM) { + for (int index = 0; index < block_c8; index += C8NUM) { float16_t res = input0[0] + input1[0]; output[0] = res > 0 ? res : 0; res = input0[1] + input1[1]; @@ -196,9 +196,9 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, output[2] = res > 0 ? res : 0; res = input0[3] + input1[3]; output[3] = res > 0 ? res : 0; - input0 += C4NUM; - input1 += C4NUM; - output += C4NUM; + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; } for (int index = 0; index < block_mod; ++index) { float16_t res = input0[index] + input1[index]; @@ -208,17 +208,17 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, } int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { - int block_mod = element_size % C4NUM; - int block_c4 = element_size - block_mod; + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; - for (int index = 0; index < block_c4; index += C4NUM) { + for (int index = 0; index < block_c8; index += C8NUM) { output[0] = MSMIN(MSMAX(input0[0] + input1[0], 0), 6); output[1] = MSMIN(MSMAX(input0[1] + input1[1], 0), 6); output[2] = MSMIN(MSMAX(input0[2] + input1[2], 0), 6); output[3] = MSMIN(MSMAX(input0[3] + input1[3], 0), 6); - input0 += C4NUM; - input1 += C4NUM; - output += C4NUM; + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; } for (int index = 0; index < block_mod; ++index) { output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6); @@ -228,17 +228,17 @@ int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, } int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { - int block_mod = element_size % C4NUM; - int block_c4 = element_size - block_mod; + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; - for (int index = 0; index < block_c4; index += C4NUM) { + for (int index = 0; index < block_c8; index += C8NUM) { output[0] = input0[0] - input1[0]; output[1] = input0[1] - input1[1]; output[2] = input0[2] - input1[2]; output[3] = input0[3] - input1[3]; - input0 += C4NUM; - input1 += C4NUM; - output += C4NUM; + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; } for (int index = 0; index < block_mod; ++index) { output[index] = input0[index] - input1[index]; @@ -247,10 +247,10 @@ int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int } int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { - int block_mod = element_size % C4NUM; - int block_c4 = element_size - block_mod; + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; - for (int index = 0; index < block_c4; index += C4NUM) { + for (int index = 0; index < block_c8; index += C8NUM) { float16_t res = input0[0] - input1[0]; output[0] = res > 0 ? res : 0; res = input0[1] - input1[1]; @@ -259,9 +259,9 @@ int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, output[2] = res > 0 ? res : 0; res = input0[3] - input1[3]; output[3] = res > 0 ? res : 0; - input0 += C4NUM; - input1 += C4NUM; - output += C4NUM; + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; } for (int index = 0; index < block_mod; ++index) { float16_t res = input0[index] - input1[index]; @@ -271,17 +271,17 @@ int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, } int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { - int block_mod = element_size % C4NUM; - int block_c4 = element_size - block_mod; + int block_mod = element_size % C8NUM; + int block_c8 = element_size - block_mod; - for (int index = 0; index < block_c4; index += C4NUM) { + for (int index = 0; index < block_c8; index += C8NUM) { output[0] = MSMIN(MSMAX(input0[0] - input1[0], 0), 6); output[1] = MSMIN(MSMAX(input0[1] - input1[1], 0), 6); output[2] = MSMIN(MSMAX(input0[2] - input1[2], 0), 6); output[3] = MSMIN(MSMAX(input0[3] - input1[3], 0), 6); - input0 += C4NUM; - input1 += C4NUM; - output += C4NUM; + input0 += C8NUM; + input1 += C8NUM; + output += C8NUM; } for (int index = 0; index < block_mod; ++index) { output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.c index 59ecac4ca11..cbd7cf90d71 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.c @@ -16,6 +16,7 @@ #include "nnacl/fp32/gather.h" #include +#include "nnacl/errorcode.h" inline int Stride(int *shape, int rank, int index) { int i, stride = 1; @@ -33,10 +34,26 @@ int Gather(float *input, int outer_size, int inner_size, int limit, int *indices float *outputm = output + inner_size * m * indices_element_size; for (i = 0; i < indices_element_size; ++i) { if (indices[i] < 0 || indices[i] > limit) { - return -1; + return NNACL_ERR; } memcpy(outputm + i * inner_size, inputm + indices[i] * inner_size, sizeof(float) * inner_size); } } - return 0; + return NNACL_OK; +} + +int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices, + int indices_element_size, int32_t *output) { + int i, m; + for (m = 0; m < outer_size; ++m) { + const int32_t *inputm = input + inner_size * m * limit; + int32_t *outputm = output + inner_size * m * indices_element_size; + for (i = 0; i < indices_element_size; ++i) { + if (indices[i] < 0 || indices[i] > limit) { + return NNACL_ERR; + } + memcpy(outputm + i * inner_size, inputm + indices[i] * inner_size, sizeof(int32_t) * inner_size); + } + } + return NNACL_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.h index 90352aab944..c94ff03c598 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/gather.h @@ -30,6 +30,8 @@ extern "C" { #endif int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size, float *output); +int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices, + int indices_element_size, int32_t *output); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.c index 0ab52d1e0a3..b02e6408f22 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.c +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.c @@ -16,11 +16,20 @@ #include "nnacl/squeeze.h" #include +#include "nnacl/errorcode.h" int DoSqueeze(float *in_data, float *out_data, size_t data_size) { if (in_data == NULL || out_data == NULL) { return -1; } (void)memcpy(out_data, in_data, data_size); - return 0; + return NNACL_OK; +} + +int DoSqueezeInt32(int32_t *in_data, int32_t *out_data, size_t data_size) { + if (in_data == NULL || out_data == NULL) { + return -1; + } + (void)memcpy(out_data, in_data, data_size); + return NNACL_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.h index 6c1863e4af2..71345f20fa7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/squeeze.h @@ -28,6 +28,7 @@ typedef struct SqueezeParameter { extern "C" { #endif int DoSqueeze(float *input_ptr, float *output_ptr, size_t data_size); +int DoSqueezeInt32(int32_t *in_data, int32_t *out_data, size_t data_size); #ifdef __cplusplus } #endif