forked from mindspore-Ecosystem/mindspore
!4541 [MS][LITE][Develop]optimize arithmetic fp16, gather and squezze support int32
Merge pull request !4541 from chenjianping/lite_dev2
This commit is contained in:
commit
2348e45815
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#include "src/runtime/kernel/arm/fp16/arithmetic_fp16.h"
|
#include "src/runtime/kernel/arm/fp16/arithmetic_fp16.h"
|
||||||
#include "src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h"
|
#include "src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h"
|
||||||
|
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
|
||||||
#include "schema/model_generated.h"
|
#include "schema/model_generated.h"
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "src/runtime/runtime_api.h"
|
#include "src/runtime/runtime_api.h"
|
||||||
|
@ -31,7 +32,7 @@ using mindspore::schema::PrimitiveType_Mul;
|
||||||
using mindspore::schema::PrimitiveType_Sub;
|
using mindspore::schema::PrimitiveType_Sub;
|
||||||
|
|
||||||
namespace mindspore::kernel {
|
namespace mindspore::kernel {
|
||||||
void ArithmeticFP16CPUKernel::FreeTileData() {
|
void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
|
||||||
if (tile_data0_ != nullptr) {
|
if (tile_data0_ != nullptr) {
|
||||||
free(tile_data0_);
|
free(tile_data0_);
|
||||||
tile_data0_ = nullptr;
|
tile_data0_ = nullptr;
|
||||||
|
@ -40,9 +41,21 @@ void ArithmeticFP16CPUKernel::FreeTileData() {
|
||||||
free(tile_data1_);
|
free(tile_data1_);
|
||||||
tile_data1_ = nullptr;
|
tile_data1_ = nullptr;
|
||||||
}
|
}
|
||||||
|
if (input0_fp16_ != nullptr) {
|
||||||
|
context_->allocator->Free(input0_fp16_);
|
||||||
|
input0_fp16_ = nullptr;
|
||||||
|
}
|
||||||
|
if (input1_fp16_ != nullptr) {
|
||||||
|
context_->allocator->Free(input1_fp16_);
|
||||||
|
input1_fp16_ = nullptr;
|
||||||
|
}
|
||||||
|
if (output_fp16_ != nullptr) {
|
||||||
|
context_->allocator->Free(output_fp16_);
|
||||||
|
output_fp16_ = nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() { FreeTileData(); }
|
ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() { FreeTmpBuffer(); }
|
||||||
|
|
||||||
int ArithmeticFP16CPUKernel::Init() {
|
int ArithmeticFP16CPUKernel::Init() {
|
||||||
switch (op_parameter_->type_) {
|
switch (op_parameter_->type_) {
|
||||||
|
@ -97,10 +110,38 @@ int ArithmeticFP16CPUKernel::Init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
int ArithmeticFP16CPUKernel::ReSize() {
|
int ArithmeticFP16CPUKernel::ReSize() {
|
||||||
FreeTileData();
|
FreeTmpBuffer();
|
||||||
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
|
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
|
||||||
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
|
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
|
||||||
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
|
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
|
||||||
|
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
|
||||||
|
input0_fp16_ = reinterpret_cast<float16_t *>(context_->allocator->Malloc(
|
||||||
|
arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
|
||||||
|
if (input0_fp16_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc data fail!";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
|
||||||
|
arithmeticParameter_->in_elements_num0_);
|
||||||
|
}
|
||||||
|
if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
|
||||||
|
input1_fp16_ = reinterpret_cast<float16_t *>(context_->allocator->Malloc(
|
||||||
|
arithmeticParameter_->in_elements_num1_ * sizeof(float16_t)));
|
||||||
|
if (input0_fp16_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc data fail!";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[1]->Data()), input1_fp16_,
|
||||||
|
arithmeticParameter_->in_elements_num1_);
|
||||||
|
}
|
||||||
|
if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
|
||||||
|
output_fp16_ = reinterpret_cast<float16_t *>(context_->allocator->Malloc(
|
||||||
|
arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
|
||||||
|
if (output_fp16_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "malloc data fail!";
|
||||||
|
return RET_ERROR;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
|
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
|
||||||
if (arithmeticParameter_->activation_type_ == schema::ActivationType_NO_ACTIVATION) {
|
if (arithmeticParameter_->activation_type_ == schema::ActivationType_NO_ACTIVATION) {
|
||||||
|
@ -137,13 +178,17 @@ int ArithmeticFP16CPUKernel::ReSize() {
|
||||||
}
|
}
|
||||||
|
|
||||||
int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
|
int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
|
||||||
auto input0_data = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
|
auto input0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
|
||||||
auto input1_data1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
|
auto input1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
|
||||||
auto output_data = reinterpret_cast<float16_t *>(out_tensors_[0]->Data());
|
auto output = reinterpret_cast<float16_t *>(out_tensors_[0]->Data());
|
||||||
auto element_num = out_tensors_[0]->ElementsNum();
|
auto element_num = out_tensors_[0]->ElementsNum();
|
||||||
|
|
||||||
|
float16_t *input0_data = input0_fp16_ == nullptr ? input0 : input0_fp16_;
|
||||||
|
float16_t *input1_data1 = input1_fp16_ == nullptr ? input1 : input1_fp16_;
|
||||||
|
auto output_data = output_fp16_ == nullptr ? output : output_fp16_;
|
||||||
int stride = UP_DIV(element_num, context_->thread_num_);
|
int stride = UP_DIV(element_num, context_->thread_num_);
|
||||||
int count = MSMIN(stride, element_num - stride * task_id);
|
int count = MSMIN(stride, element_num - stride * task_id);
|
||||||
|
auto thread_stride = stride * task_id;
|
||||||
|
|
||||||
if (arithmetic_run_ == nullptr) {
|
if (arithmetic_run_ == nullptr) {
|
||||||
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
|
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
|
||||||
|
@ -152,26 +197,30 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
|
||||||
|
|
||||||
int error_code = RET_OK;
|
int error_code = RET_OK;
|
||||||
if (arithmeticParameter_->broadcasting_) {
|
if (arithmeticParameter_->broadcasting_) {
|
||||||
error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id,
|
error_code = arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride,
|
||||||
output_data + stride * task_id, count);
|
output_data + thread_stride, count);
|
||||||
} else if (arithmetic_opt_run_ != nullptr) {
|
} else if (arithmetic_opt_run_ != nullptr) {
|
||||||
if (arithmeticParameter_->in_elements_num0_ == 1) {
|
if (arithmeticParameter_->in_elements_num0_ == 1) {
|
||||||
error_code = arithmetic_opt_run_(input0_data, input1_data1 + stride * task_id, output_data + stride * task_id,
|
error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride,
|
||||||
count, arithmeticParameter_);
|
count, arithmeticParameter_);
|
||||||
} else if (arithmeticParameter_->in_elements_num1_ == 1) {
|
} else if (arithmeticParameter_->in_elements_num1_ == 1) {
|
||||||
error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1, output_data + stride * task_id,
|
error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1, output_data + thread_stride,
|
||||||
count, arithmeticParameter_);
|
count, arithmeticParameter_);
|
||||||
} else {
|
} else {
|
||||||
error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
|
error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1 + thread_stride,
|
||||||
output_data + stride * task_id, count, arithmeticParameter_);
|
output_data + thread_stride, count, arithmeticParameter_);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
error_code = arithmetic_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
|
error_code = arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride,
|
||||||
output_data + stride * task_id, count);
|
output_data + thread_stride, count);
|
||||||
}
|
}
|
||||||
if (error_code != RET_OK) {
|
if (error_code != RET_OK) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
if (output_fp16_ != nullptr) {
|
||||||
|
auto output_fp32 = reinterpret_cast<float *>(out_tensors_[0]->Data());
|
||||||
|
Float16ToFloat32(output_data + thread_stride, output_fp32 + thread_stride, count);
|
||||||
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -195,7 +244,9 @@ int ArithmeticFP16CPUKernel::Run() {
|
||||||
if (arithmeticParameter_->broadcasting_) {
|
if (arithmeticParameter_->broadcasting_) {
|
||||||
auto input_data0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
|
auto input_data0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
|
||||||
auto input_data1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
|
auto input_data1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
|
||||||
TileDimensionsFp16(input_data0, input_data1, tile_data0_, tile_data1_, arithmeticParameter_);
|
float16_t *input0 = input0_fp16_ == nullptr ? input_data0 : input0_fp16_;
|
||||||
|
float16_t *input1 = input1_fp16_ == nullptr ? input_data1 : input1_fp16_;
|
||||||
|
TileDimensionsFp16(input0, input1, tile_data0_, tile_data1_, arithmeticParameter_);
|
||||||
}
|
}
|
||||||
ret = LiteBackendParallelLaunch(ArithmeticsRun, this, context_->thread_num_);
|
ret = LiteBackendParallelLaunch(ArithmeticsRun, this, context_->thread_num_);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
|
|
|
@ -43,9 +43,12 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
|
||||||
int DoArithmetic(int task_id);
|
int DoArithmetic(int task_id);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void FreeTileData();
|
void FreeTmpBuffer();
|
||||||
float16_t *tile_data0_ = nullptr;
|
float16_t *tile_data0_ = nullptr;
|
||||||
float16_t *tile_data1_ = nullptr;
|
float16_t *tile_data1_ = nullptr;
|
||||||
|
float16_t *input0_fp16_ = nullptr;
|
||||||
|
float16_t *input1_fp16_ = nullptr;
|
||||||
|
float16_t *output_fp16_ = nullptr;
|
||||||
ArithmeticParameter *arithmeticParameter_ = nullptr;
|
ArithmeticParameter *arithmeticParameter_ = nullptr;
|
||||||
ArithmeticRun arithmetic_run_ = nullptr;
|
ArithmeticRun arithmetic_run_ = nullptr;
|
||||||
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
|
ArithmeticOptRun arithmetic_opt_run_ = nullptr;
|
||||||
|
|
|
@ -49,6 +49,9 @@ int GatherCPUKernel::DoGather(int task_id) {
|
||||||
auto indices_ptr = reinterpret_cast<int *>(indices_tensor->Data());
|
auto indices_ptr = reinterpret_cast<int *>(indices_tensor->Data());
|
||||||
auto output_ptr = reinterpret_cast<float *>(out_tensor->Data());
|
auto output_ptr = reinterpret_cast<float *>(out_tensor->Data());
|
||||||
|
|
||||||
|
auto input_int32 = reinterpret_cast<int32_t *>(input_tensor->Data());
|
||||||
|
auto output_int32 = reinterpret_cast<int32_t *>(out_tensor->Data());
|
||||||
|
|
||||||
auto in_shape = input_tensor->shape();
|
auto in_shape = input_tensor->shape();
|
||||||
int in_rank = in_shape.size();
|
int in_rank = in_shape.size();
|
||||||
int indices_element_size = indices_tensor->ElementsNum();
|
int indices_element_size = indices_tensor->ElementsNum();
|
||||||
|
@ -73,11 +76,19 @@ int GatherCPUKernel::DoGather(int task_id) {
|
||||||
|
|
||||||
int stride = UP_DIV(outer_size, thread_count_);
|
int stride = UP_DIV(outer_size, thread_count_);
|
||||||
int count = MSMIN(stride, outer_size - stride * task_id);
|
int count = MSMIN(stride, outer_size - stride * task_id);
|
||||||
|
auto thread_stride = stride * task_id;
|
||||||
|
|
||||||
input_ptr += stride * task_id * limit;
|
int error_code;
|
||||||
output_ptr += stride * task_id * indices_element_size;
|
if (input_tensor->data_type() == kNumberTypeInt32) {
|
||||||
|
input_int32 += thread_stride * limit;
|
||||||
|
output_int32 += thread_stride * indices_element_size;
|
||||||
|
error_code = GatherInt32(input_int32, count, inner_size, limit, indices_ptr, indices_element_size, output_int32);
|
||||||
|
} else {
|
||||||
|
input_ptr += thread_stride * limit;
|
||||||
|
output_ptr += thread_stride * indices_element_size;
|
||||||
|
error_code = Gather(input_ptr, count, inner_size, limit, indices_ptr, indices_element_size, output_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
auto error_code = Gather(input_ptr, count, inner_size, limit, indices_ptr, indices_element_size, output_ptr);
|
|
||||||
if (error_code != RET_OK) {
|
if (error_code != RET_OK) {
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
|
@ -110,19 +121,21 @@ int GatherCPUKernel::Run() {
|
||||||
|
|
||||||
kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||||
OpParameter *opParameter, const lite::Context *ctx,
|
OpParameter *parameter, const lite::Context *ctx,
|
||||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||||
MS_ASSERT(opParameter != nullptr);
|
|
||||||
MS_ASSERT(desc.type == schema::PrimitiveType_Gather);
|
MS_ASSERT(desc.type == schema::PrimitiveType_Gather);
|
||||||
|
if (parameter == nullptr) {
|
||||||
auto *kernel = new (std::nothrow) GatherCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
MS_LOG(ERROR) << "input parameter is nullptr!";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto *kernel = new (std::nothrow) GatherCPUKernel(parameter, inputs, outputs, ctx, primitive);
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
|
||||||
delete kernel;
|
delete kernel;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -130,4 +143,5 @@ kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector<lite::tensor::T
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gather, CpuGatherFp32KernelCreator)
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gather, CpuGatherFp32KernelCreator)
|
||||||
|
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Gather, CpuGatherFp32KernelCreator)
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
|
@ -42,12 +42,20 @@ int SqueezeCPUKernel::Run() {
|
||||||
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
|
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t data_size = in_tensors_.front()->Size();
|
||||||
|
if (in_tensors_.front()->data_type() == kNumberTypeInt32) {
|
||||||
|
auto input_ptr = reinterpret_cast<int32_t *>(in_tensors_.front()->Data());
|
||||||
|
auto output_ptr = reinterpret_cast<int32_t *>(out_tensors_.front()->Data());
|
||||||
|
ret = DoSqueezeInt32(input_ptr, output_ptr, data_size);
|
||||||
|
} else {
|
||||||
auto input_ptr = reinterpret_cast<float *>(in_tensors_.front()->Data());
|
auto input_ptr = reinterpret_cast<float *>(in_tensors_.front()->Data());
|
||||||
auto output_ptr = reinterpret_cast<float *>(out_tensors_.front()->Data());
|
auto output_ptr = reinterpret_cast<float *>(out_tensors_.front()->Data());
|
||||||
size_t data_size = in_tensors_.front()->Size();
|
|
||||||
ret = DoSqueeze(input_ptr, output_ptr, data_size);
|
ret = DoSqueeze(input_ptr, output_ptr, data_size);
|
||||||
|
}
|
||||||
|
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Do squeeze failed.";
|
MS_LOG(ERROR) << "Do squeeze fail!ret: " << ret;
|
||||||
return RET_ERROR;
|
return RET_ERROR;
|
||||||
}
|
}
|
||||||
return RET_OK;
|
return RET_OK;
|
||||||
|
@ -55,14 +63,14 @@ int SqueezeCPUKernel::Run() {
|
||||||
|
|
||||||
kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
||||||
const std::vector<lite::tensor::Tensor *> &outputs,
|
const std::vector<lite::tensor::Tensor *> &outputs,
|
||||||
OpParameter *opParameter, const lite::Context *ctx,
|
OpParameter *parameter, const lite::Context *ctx,
|
||||||
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
|
||||||
MS_ASSERT(desc.type == schema::PrimitiveType_Squeeze);
|
MS_ASSERT(desc.type == schema::PrimitiveType_Squeeze);
|
||||||
if (opParameter == nullptr) {
|
if (parameter == nullptr) {
|
||||||
MS_LOG(ERROR) << "desc type is not Squeeze";
|
MS_LOG(ERROR) << "desc type is not Squeeze";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
auto *kernel = new (std::nothrow) SqueezeCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
auto *kernel = new (std::nothrow) SqueezeCPUKernel(parameter, inputs, outputs, ctx, primitive);
|
||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
MS_LOG(ERROR) << "New kernel fails.";
|
MS_LOG(ERROR) << "New kernel fails.";
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -70,8 +78,8 @@ kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector<lite::tensor::
|
||||||
|
|
||||||
auto ret = kernel->Init();
|
auto ret = kernel->Init();
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ << ", type: "
|
||||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
|
||||||
delete kernel;
|
delete kernel;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -80,4 +88,5 @@ kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector<lite::tensor::
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
|
||||||
|
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
|
||||||
} // namespace mindspore::kernel
|
} // namespace mindspore::kernel
|
||||||
|
|
|
@ -100,17 +100,17 @@ int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, i
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C8NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c8 = element_size - block_mod;
|
||||||
|
|
||||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||||
output[0] = input0[0] * input1[0];
|
output[0] = input0[0] * input1[0];
|
||||||
output[1] = input0[1] * input1[1];
|
output[1] = input0[1] * input1[1];
|
||||||
output[2] = input0[2] * input1[2];
|
output[2] = input0[2] * input1[2];
|
||||||
output[3] = input0[3] * input1[3];
|
output[3] = input0[3] * input1[3];
|
||||||
input0 += C4NUM;
|
input0 += C8NUM;
|
||||||
input1 += C4NUM;
|
input1 += C8NUM;
|
||||||
output += C4NUM;
|
output += C8NUM;
|
||||||
}
|
}
|
||||||
for (int index = 0; index < block_mod; ++index) {
|
for (int index = 0; index < block_mod; ++index) {
|
||||||
output[index] = input0[index] * input1[index];
|
output[index] = input0[index] * input1[index];
|
||||||
|
@ -120,10 +120,10 @@ int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C8NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c8 = element_size - block_mod;
|
||||||
|
|
||||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||||
float16_t res = input0[0] * input1[0];
|
float16_t res = input0[0] * input1[0];
|
||||||
output[0] = res > 0 ? res : 0;
|
output[0] = res > 0 ? res : 0;
|
||||||
res = input0[1] * input1[1];
|
res = input0[1] * input1[1];
|
||||||
|
@ -132,9 +132,9 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
|
||||||
output[2] = res > 0 ? res : 0;
|
output[2] = res > 0 ? res : 0;
|
||||||
res = input0[3] * input1[3];
|
res = input0[3] * input1[3];
|
||||||
output[3] = res > 0 ? res : 0;
|
output[3] = res > 0 ? res : 0;
|
||||||
input0 += C4NUM;
|
input0 += C8NUM;
|
||||||
input1 += C4NUM;
|
input1 += C8NUM;
|
||||||
output += C4NUM;
|
output += C8NUM;
|
||||||
}
|
}
|
||||||
for (int index = 0; index < block_mod; ++index) {
|
for (int index = 0; index < block_mod; ++index) {
|
||||||
float16_t res = input0[index] * input1[index];
|
float16_t res = input0[index] * input1[index];
|
||||||
|
@ -145,17 +145,17 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C8NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c8 = element_size - block_mod;
|
||||||
|
|
||||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||||
output[0] = MSMIN(MSMAX(input0[0] * input1[0], 0), 6);
|
output[0] = MSMIN(MSMAX(input0[0] * input1[0], 0), 6);
|
||||||
output[1] = MSMIN(MSMAX(input0[1] * input1[1], 0), 6);
|
output[1] = MSMIN(MSMAX(input0[1] * input1[1], 0), 6);
|
||||||
output[2] = MSMIN(MSMAX(input0[2] * input1[2], 0), 6);
|
output[2] = MSMIN(MSMAX(input0[2] * input1[2], 0), 6);
|
||||||
output[3] = MSMIN(MSMAX(input0[3] * input1[3], 0), 6);
|
output[3] = MSMIN(MSMAX(input0[3] * input1[3], 0), 6);
|
||||||
input0 += C4NUM;
|
input0 += C8NUM;
|
||||||
input1 += C4NUM;
|
input1 += C8NUM;
|
||||||
output += C4NUM;
|
output += C8NUM;
|
||||||
}
|
}
|
||||||
for (int index = 0; index < block_mod; ++index) {
|
for (int index = 0; index < block_mod; ++index) {
|
||||||
output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6);
|
output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6);
|
||||||
|
@ -165,17 +165,17 @@ int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C8NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c8 = element_size - block_mod;
|
||||||
|
|
||||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||||
output[0] = input0[0] + input1[0];
|
output[0] = input0[0] + input1[0];
|
||||||
output[1] = input0[1] + input1[1];
|
output[1] = input0[1] + input1[1];
|
||||||
output[2] = input0[2] + input1[2];
|
output[2] = input0[2] + input1[2];
|
||||||
output[3] = input0[3] + input1[3];
|
output[3] = input0[3] + input1[3];
|
||||||
input0 += C4NUM;
|
input0 += C8NUM;
|
||||||
input1 += C4NUM;
|
input1 += C8NUM;
|
||||||
output += C4NUM;
|
output += C8NUM;
|
||||||
}
|
}
|
||||||
for (int index = 0; index < block_mod; ++index) {
|
for (int index = 0; index < block_mod; ++index) {
|
||||||
output[index] = input0[index] + input1[index];
|
output[index] = input0[index] + input1[index];
|
||||||
|
@ -184,10 +184,10 @@ int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C8NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c8 = element_size - block_mod;
|
||||||
|
|
||||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||||
float16_t res = input0[0] + input1[0];
|
float16_t res = input0[0] + input1[0];
|
||||||
output[0] = res > 0 ? res : 0;
|
output[0] = res > 0 ? res : 0;
|
||||||
res = input0[1] + input1[1];
|
res = input0[1] + input1[1];
|
||||||
|
@ -196,9 +196,9 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
|
||||||
output[2] = res > 0 ? res : 0;
|
output[2] = res > 0 ? res : 0;
|
||||||
res = input0[3] + input1[3];
|
res = input0[3] + input1[3];
|
||||||
output[3] = res > 0 ? res : 0;
|
output[3] = res > 0 ? res : 0;
|
||||||
input0 += C4NUM;
|
input0 += C8NUM;
|
||||||
input1 += C4NUM;
|
input1 += C8NUM;
|
||||||
output += C4NUM;
|
output += C8NUM;
|
||||||
}
|
}
|
||||||
for (int index = 0; index < block_mod; ++index) {
|
for (int index = 0; index < block_mod; ++index) {
|
||||||
float16_t res = input0[index] + input1[index];
|
float16_t res = input0[index] + input1[index];
|
||||||
|
@ -208,17 +208,17 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C8NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c8 = element_size - block_mod;
|
||||||
|
|
||||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||||
output[0] = MSMIN(MSMAX(input0[0] + input1[0], 0), 6);
|
output[0] = MSMIN(MSMAX(input0[0] + input1[0], 0), 6);
|
||||||
output[1] = MSMIN(MSMAX(input0[1] + input1[1], 0), 6);
|
output[1] = MSMIN(MSMAX(input0[1] + input1[1], 0), 6);
|
||||||
output[2] = MSMIN(MSMAX(input0[2] + input1[2], 0), 6);
|
output[2] = MSMIN(MSMAX(input0[2] + input1[2], 0), 6);
|
||||||
output[3] = MSMIN(MSMAX(input0[3] + input1[3], 0), 6);
|
output[3] = MSMIN(MSMAX(input0[3] + input1[3], 0), 6);
|
||||||
input0 += C4NUM;
|
input0 += C8NUM;
|
||||||
input1 += C4NUM;
|
input1 += C8NUM;
|
||||||
output += C4NUM;
|
output += C8NUM;
|
||||||
}
|
}
|
||||||
for (int index = 0; index < block_mod; ++index) {
|
for (int index = 0; index < block_mod; ++index) {
|
||||||
output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6);
|
output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6);
|
||||||
|
@ -228,17 +228,17 @@ int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C8NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c8 = element_size - block_mod;
|
||||||
|
|
||||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||||
output[0] = input0[0] - input1[0];
|
output[0] = input0[0] - input1[0];
|
||||||
output[1] = input0[1] - input1[1];
|
output[1] = input0[1] - input1[1];
|
||||||
output[2] = input0[2] - input1[2];
|
output[2] = input0[2] - input1[2];
|
||||||
output[3] = input0[3] - input1[3];
|
output[3] = input0[3] - input1[3];
|
||||||
input0 += C4NUM;
|
input0 += C8NUM;
|
||||||
input1 += C4NUM;
|
input1 += C8NUM;
|
||||||
output += C4NUM;
|
output += C8NUM;
|
||||||
}
|
}
|
||||||
for (int index = 0; index < block_mod; ++index) {
|
for (int index = 0; index < block_mod; ++index) {
|
||||||
output[index] = input0[index] - input1[index];
|
output[index] = input0[index] - input1[index];
|
||||||
|
@ -247,10 +247,10 @@ int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C8NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c8 = element_size - block_mod;
|
||||||
|
|
||||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||||
float16_t res = input0[0] - input1[0];
|
float16_t res = input0[0] - input1[0];
|
||||||
output[0] = res > 0 ? res : 0;
|
output[0] = res > 0 ? res : 0;
|
||||||
res = input0[1] - input1[1];
|
res = input0[1] - input1[1];
|
||||||
|
@ -259,9 +259,9 @@ int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
|
||||||
output[2] = res > 0 ? res : 0;
|
output[2] = res > 0 ? res : 0;
|
||||||
res = input0[3] - input1[3];
|
res = input0[3] - input1[3];
|
||||||
output[3] = res > 0 ? res : 0;
|
output[3] = res > 0 ? res : 0;
|
||||||
input0 += C4NUM;
|
input0 += C8NUM;
|
||||||
input1 += C4NUM;
|
input1 += C8NUM;
|
||||||
output += C4NUM;
|
output += C8NUM;
|
||||||
}
|
}
|
||||||
for (int index = 0; index < block_mod; ++index) {
|
for (int index = 0; index < block_mod; ++index) {
|
||||||
float16_t res = input0[index] - input1[index];
|
float16_t res = input0[index] - input1[index];
|
||||||
|
@ -271,17 +271,17 @@ int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
|
||||||
}
|
}
|
||||||
|
|
||||||
int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
|
||||||
int block_mod = element_size % C4NUM;
|
int block_mod = element_size % C8NUM;
|
||||||
int block_c4 = element_size - block_mod;
|
int block_c8 = element_size - block_mod;
|
||||||
|
|
||||||
for (int index = 0; index < block_c4; index += C4NUM) {
|
for (int index = 0; index < block_c8; index += C8NUM) {
|
||||||
output[0] = MSMIN(MSMAX(input0[0] - input1[0], 0), 6);
|
output[0] = MSMIN(MSMAX(input0[0] - input1[0], 0), 6);
|
||||||
output[1] = MSMIN(MSMAX(input0[1] - input1[1], 0), 6);
|
output[1] = MSMIN(MSMAX(input0[1] - input1[1], 0), 6);
|
||||||
output[2] = MSMIN(MSMAX(input0[2] - input1[2], 0), 6);
|
output[2] = MSMIN(MSMAX(input0[2] - input1[2], 0), 6);
|
||||||
output[3] = MSMIN(MSMAX(input0[3] - input1[3], 0), 6);
|
output[3] = MSMIN(MSMAX(input0[3] - input1[3], 0), 6);
|
||||||
input0 += C4NUM;
|
input0 += C8NUM;
|
||||||
input1 += C4NUM;
|
input1 += C8NUM;
|
||||||
output += C4NUM;
|
output += C8NUM;
|
||||||
}
|
}
|
||||||
for (int index = 0; index < block_mod; ++index) {
|
for (int index = 0; index < block_mod; ++index) {
|
||||||
output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6);
|
output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6);
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
#include "nnacl/fp32/gather.h"
|
#include "nnacl/fp32/gather.h"
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
#include "nnacl/errorcode.h"
|
||||||
|
|
||||||
inline int Stride(int *shape, int rank, int index) {
|
inline int Stride(int *shape, int rank, int index) {
|
||||||
int i, stride = 1;
|
int i, stride = 1;
|
||||||
|
@ -33,10 +34,26 @@ int Gather(float *input, int outer_size, int inner_size, int limit, int *indices
|
||||||
float *outputm = output + inner_size * m * indices_element_size;
|
float *outputm = output + inner_size * m * indices_element_size;
|
||||||
for (i = 0; i < indices_element_size; ++i) {
|
for (i = 0; i < indices_element_size; ++i) {
|
||||||
if (indices[i] < 0 || indices[i] > limit) {
|
if (indices[i] < 0 || indices[i] > limit) {
|
||||||
return -1;
|
return NNACL_ERR;
|
||||||
}
|
}
|
||||||
memcpy(outputm + i * inner_size, inputm + indices[i] * inner_size, sizeof(float) * inner_size);
|
memcpy(outputm + i * inner_size, inputm + indices[i] * inner_size, sizeof(float) * inner_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return 0;
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices,
|
||||||
|
int indices_element_size, int32_t *output) {
|
||||||
|
int i, m;
|
||||||
|
for (m = 0; m < outer_size; ++m) {
|
||||||
|
const int32_t *inputm = input + inner_size * m * limit;
|
||||||
|
int32_t *outputm = output + inner_size * m * indices_element_size;
|
||||||
|
for (i = 0; i < indices_element_size; ++i) {
|
||||||
|
if (indices[i] < 0 || indices[i] > limit) {
|
||||||
|
return NNACL_ERR;
|
||||||
|
}
|
||||||
|
memcpy(outputm + i * inner_size, inputm + indices[i] * inner_size, sizeof(int32_t) * inner_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,6 +30,8 @@ extern "C" {
|
||||||
#endif
|
#endif
|
||||||
int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
|
int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
|
||||||
float *output);
|
float *output);
|
||||||
|
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices,
|
||||||
|
int indices_element_size, int32_t *output);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -16,11 +16,20 @@
|
||||||
|
|
||||||
#include "nnacl/squeeze.h"
|
#include "nnacl/squeeze.h"
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
#include "nnacl/errorcode.h"
|
||||||
|
|
||||||
int DoSqueeze(float *in_data, float *out_data, size_t data_size) {
|
int DoSqueeze(float *in_data, float *out_data, size_t data_size) {
|
||||||
if (in_data == NULL || out_data == NULL) {
|
if (in_data == NULL || out_data == NULL) {
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
(void)memcpy(out_data, in_data, data_size);
|
(void)memcpy(out_data, in_data, data_size);
|
||||||
return 0;
|
return NNACL_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
int DoSqueezeInt32(int32_t *in_data, int32_t *out_data, size_t data_size) {
|
||||||
|
if (in_data == NULL || out_data == NULL) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
(void)memcpy(out_data, in_data, data_size);
|
||||||
|
return NNACL_OK;
|
||||||
}
|
}
|
||||||
|
|
|
@ -28,6 +28,7 @@ typedef struct SqueezeParameter {
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
int DoSqueeze(float *input_ptr, float *output_ptr, size_t data_size);
|
int DoSqueeze(float *input_ptr, float *output_ptr, size_t data_size);
|
||||||
|
int DoSqueezeInt32(int32_t *in_data, int32_t *out_data, size_t data_size);
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
Loading…
Reference in New Issue