arithemetic support fp16,gather, squezze support int32

This commit is contained in:
chenjianping 2020-08-16 18:21:52 +08:00
parent 125d021e1a
commit 3b1a048677
9 changed files with 197 additions and 91 deletions

View File

@ -16,6 +16,7 @@
#include "src/runtime/kernel/arm/fp16/arithmetic_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h"
#include "src/runtime/kernel/arm/nnacl/fp16/cast_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
@ -31,7 +32,7 @@ using mindspore::schema::PrimitiveType_Mul;
using mindspore::schema::PrimitiveType_Sub;
namespace mindspore::kernel {
void ArithmeticFP16CPUKernel::FreeTileData() {
void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
if (tile_data0_ != nullptr) {
free(tile_data0_);
tile_data0_ = nullptr;
@ -40,9 +41,21 @@ void ArithmeticFP16CPUKernel::FreeTileData() {
free(tile_data1_);
tile_data1_ = nullptr;
}
if (input0_fp16_ != nullptr) {
context_->allocator->Free(input0_fp16_);
input0_fp16_ = nullptr;
}
if (input1_fp16_ != nullptr) {
context_->allocator->Free(input1_fp16_);
input1_fp16_ = nullptr;
}
if (output_fp16_ != nullptr) {
context_->allocator->Free(output_fp16_);
output_fp16_ = nullptr;
}
}
ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() { FreeTileData(); }
ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() { FreeTmpBuffer(); }
int ArithmeticFP16CPUKernel::Init() {
switch (op_parameter_->type_) {
@ -97,10 +110,38 @@ int ArithmeticFP16CPUKernel::Init() {
}
int ArithmeticFP16CPUKernel::ReSize() {
FreeTileData();
FreeTmpBuffer();
arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
input0_fp16_ = reinterpret_cast<float16_t *>(context_->allocator->Malloc(
arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
if (input0_fp16_ == nullptr) {
MS_LOG(ERROR) << "malloc data fail!";
return RET_ERROR;
}
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
arithmeticParameter_->in_elements_num0_);
}
if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
input1_fp16_ = reinterpret_cast<float16_t *>(context_->allocator->Malloc(
arithmeticParameter_->in_elements_num1_ * sizeof(float16_t)));
if (input0_fp16_ == nullptr) {
MS_LOG(ERROR) << "malloc data fail!";
return RET_ERROR;
}
Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[1]->Data()), input1_fp16_,
arithmeticParameter_->in_elements_num1_);
}
if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
output_fp16_ = reinterpret_cast<float16_t *>(context_->allocator->Malloc(
arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
if (output_fp16_ == nullptr) {
MS_LOG(ERROR) << "malloc data fail!";
return RET_ERROR;
}
}
if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) {
if (arithmeticParameter_->activation_type_ == schema::ActivationType_NO_ACTIVATION) {
@ -137,13 +178,17 @@ int ArithmeticFP16CPUKernel::ReSize() {
}
int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
auto input0_data = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
auto input1_data1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
auto output_data = reinterpret_cast<float16_t *>(out_tensors_[0]->Data());
auto input0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
auto input1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
auto output = reinterpret_cast<float16_t *>(out_tensors_[0]->Data());
auto element_num = out_tensors_[0]->ElementsNum();
float16_t *input0_data = input0_fp16_ == nullptr ? input0 : input0_fp16_;
float16_t *input1_data1 = input1_fp16_ == nullptr ? input1 : input1_fp16_;
auto output_data = output_fp16_ == nullptr ? output : output_fp16_;
int stride = UP_DIV(element_num, context_->thread_num_);
int count = MSMIN(stride, element_num - stride * task_id);
auto thread_stride = stride * task_id;
if (arithmetic_run_ == nullptr) {
MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
@ -152,26 +197,30 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
int error_code = RET_OK;
if (arithmeticParameter_->broadcasting_) {
error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id,
output_data + stride * task_id, count);
error_code = arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride,
output_data + thread_stride, count);
} else if (arithmetic_opt_run_ != nullptr) {
if (arithmeticParameter_->in_elements_num0_ == 1) {
error_code = arithmetic_opt_run_(input0_data, input1_data1 + stride * task_id, output_data + stride * task_id,
error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride,
count, arithmeticParameter_);
} else if (arithmeticParameter_->in_elements_num1_ == 1) {
error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1, output_data + stride * task_id,
error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1, output_data + thread_stride,
count, arithmeticParameter_);
} else {
error_code = arithmetic_opt_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
output_data + stride * task_id, count, arithmeticParameter_);
error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1 + thread_stride,
output_data + thread_stride, count, arithmeticParameter_);
}
} else {
error_code = arithmetic_run_(input0_data + stride * task_id, input1_data1 + stride * task_id,
output_data + stride * task_id, count);
error_code = arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride,
output_data + thread_stride, count);
}
if (error_code != RET_OK) {
return RET_ERROR;
}
if (output_fp16_ != nullptr) {
auto output_fp32 = reinterpret_cast<float *>(out_tensors_[0]->Data());
Float16ToFloat32(output_data + thread_stride, output_fp32 + thread_stride, count);
}
return RET_OK;
}
@ -195,7 +244,9 @@ int ArithmeticFP16CPUKernel::Run() {
if (arithmeticParameter_->broadcasting_) {
auto input_data0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
auto input_data1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
TileDimensionsFp16(input_data0, input_data1, tile_data0_, tile_data1_, arithmeticParameter_);
float16_t *input0 = input0_fp16_ == nullptr ? input_data0 : input0_fp16_;
float16_t *input1 = input1_fp16_ == nullptr ? input_data1 : input1_fp16_;
TileDimensionsFp16(input0, input1, tile_data0_, tile_data1_, arithmeticParameter_);
}
ret = LiteBackendParallelLaunch(ArithmeticsRun, this, context_->thread_num_);
if (ret != RET_OK) {

View File

@ -43,9 +43,12 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
int DoArithmetic(int task_id);
private:
void FreeTileData();
void FreeTmpBuffer();
float16_t *tile_data0_ = nullptr;
float16_t *tile_data1_ = nullptr;
float16_t *input0_fp16_ = nullptr;
float16_t *input1_fp16_ = nullptr;
float16_t *output_fp16_ = nullptr;
ArithmeticParameter *arithmeticParameter_ = nullptr;
ArithmeticRun arithmetic_run_ = nullptr;
ArithmeticOptRun arithmetic_opt_run_ = nullptr;

View File

@ -49,6 +49,9 @@ int GatherCPUKernel::DoGather(int task_id) {
auto indices_ptr = reinterpret_cast<int *>(indices_tensor->Data());
auto output_ptr = reinterpret_cast<float *>(out_tensor->Data());
auto input_int32 = reinterpret_cast<int32_t *>(input_tensor->Data());
auto output_int32 = reinterpret_cast<int32_t *>(out_tensor->Data());
auto in_shape = input_tensor->shape();
int in_rank = in_shape.size();
int indices_element_size = indices_tensor->ElementsNum();
@ -73,11 +76,19 @@ int GatherCPUKernel::DoGather(int task_id) {
int stride = UP_DIV(outer_size, thread_count_);
int count = MSMIN(stride, outer_size - stride * task_id);
auto thread_stride = stride * task_id;
input_ptr += stride * task_id * limit;
output_ptr += stride * task_id * indices_element_size;
int error_code;
if (input_tensor->data_type() == kNumberTypeInt32) {
input_int32 += thread_stride * limit;
output_int32 += thread_stride * indices_element_size;
error_code = GatherInt32(input_int32, count, inner_size, limit, indices_ptr, indices_element_size, output_int32);
} else {
input_ptr += thread_stride * limit;
output_ptr += thread_stride * indices_element_size;
error_code = Gather(input_ptr, count, inner_size, limit, indices_ptr, indices_element_size, output_ptr);
}
auto error_code = Gather(input_ptr, count, inner_size, limit, indices_ptr, indices_element_size, output_ptr);
if (error_code != RET_OK) {
return RET_ERROR;
}
@ -110,19 +121,21 @@ int GatherCPUKernel::Run() {
kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
OpParameter *parameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_Gather);
auto *kernel = new (std::nothrow) GatherCPUKernel(opParameter, inputs, outputs, ctx, primitive);
if (parameter == nullptr) {
MS_LOG(ERROR) << "input parameter is nullptr!";
return nullptr;
}
auto *kernel = new (std::nothrow) GatherCPUKernel(parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
delete kernel;
return nullptr;
}
@ -130,4 +143,5 @@ kernel::LiteKernel *CpuGatherFp32KernelCreator(const std::vector<lite::tensor::T
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gather, CpuGatherFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Gather, CpuGatherFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -42,12 +42,20 @@ int SqueezeCPUKernel::Run() {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
auto input_ptr = reinterpret_cast<float *>(in_tensors_.front()->Data());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.front()->Data());
size_t data_size = in_tensors_.front()->Size();
ret = DoSqueeze(input_ptr, output_ptr, data_size);
if (in_tensors_.front()->data_type() == kNumberTypeInt32) {
auto input_ptr = reinterpret_cast<int32_t *>(in_tensors_.front()->Data());
auto output_ptr = reinterpret_cast<int32_t *>(out_tensors_.front()->Data());
ret = DoSqueezeInt32(input_ptr, output_ptr, data_size);
} else {
auto input_ptr = reinterpret_cast<float *>(in_tensors_.front()->Data());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.front()->Data());
ret = DoSqueeze(input_ptr, output_ptr, data_size);
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "Do squeeze failed.";
MS_LOG(ERROR) << "Do squeeze fail!ret: " << ret;
return RET_ERROR;
}
return RET_OK;
@ -55,14 +63,14 @@ int SqueezeCPUKernel::Run() {
kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs,
OpParameter *opParameter, const lite::Context *ctx,
OpParameter *parameter, const lite::Context *ctx,
const kernel::KernelKey &desc, const lite::Primitive *primitive) {
MS_ASSERT(desc.type == schema::PrimitiveType_Squeeze);
if (opParameter == nullptr) {
if (parameter == nullptr) {
MS_LOG(ERROR) << "desc type is not Squeeze";
return nullptr;
}
auto *kernel = new (std::nothrow) SqueezeCPUKernel(opParameter, inputs, outputs, ctx, primitive);
auto *kernel = new (std::nothrow) SqueezeCPUKernel(parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "New kernel fails.";
return nullptr;
@ -70,8 +78,8 @@ kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector<lite::tensor::
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
delete kernel;
return nullptr;
}
@ -80,4 +88,5 @@ kernel::LiteKernel *CpuSqueezeFp32KernelCreator(const std::vector<lite::tensor::
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Squeeze, CpuSqueezeFp32KernelCreator)
} // namespace mindspore::kernel

View File

@ -100,17 +100,17 @@ int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, i
}
int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
for (int index = 0; index < block_c8; index += C8NUM) {
output[0] = input0[0] * input1[0];
output[1] = input0[1] * input1[1];
output[2] = input0[2] * input1[2];
output[3] = input0[3] * input1[3];
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] * input1[index];
@ -120,10 +120,10 @@ int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int
}
int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
for (int index = 0; index < block_c8; index += C8NUM) {
float16_t res = input0[0] * input1[0];
output[0] = res > 0 ? res : 0;
res = input0[1] * input1[1];
@ -132,9 +132,9 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
output[2] = res > 0 ? res : 0;
res = input0[3] * input1[3];
output[3] = res > 0 ? res : 0;
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = input0[index] * input1[index];
@ -145,17 +145,17 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
}
int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
for (int index = 0; index < block_c8; index += C8NUM) {
output[0] = MSMIN(MSMAX(input0[0] * input1[0], 0), 6);
output[1] = MSMIN(MSMAX(input0[1] * input1[1], 0), 6);
output[2] = MSMIN(MSMAX(input0[2] * input1[2], 0), 6);
output[3] = MSMIN(MSMAX(input0[3] * input1[3], 0), 6);
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] * input1[index], 0), 6);
@ -165,17 +165,17 @@ int ElementMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
}
int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
for (int index = 0; index < block_c8; index += C8NUM) {
output[0] = input0[0] + input1[0];
output[1] = input0[1] + input1[1];
output[2] = input0[2] + input1[2];
output[3] = input0[3] + input1[3];
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] + input1[index];
@ -184,10 +184,10 @@ int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int
}
int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
for (int index = 0; index < block_c8; index += C8NUM) {
float16_t res = input0[0] + input1[0];
output[0] = res > 0 ? res : 0;
res = input0[1] + input1[1];
@ -196,9 +196,9 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
output[2] = res > 0 ? res : 0;
res = input0[3] + input1[3];
output[3] = res > 0 ? res : 0;
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = input0[index] + input1[index];
@ -208,17 +208,17 @@ int ElementAddReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
}
int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
for (int index = 0; index < block_c8; index += C8NUM) {
output[0] = MSMIN(MSMAX(input0[0] + input1[0], 0), 6);
output[1] = MSMIN(MSMAX(input0[1] + input1[1], 0), 6);
output[2] = MSMIN(MSMAX(input0[2] + input1[2], 0), 6);
output[3] = MSMIN(MSMAX(input0[3] + input1[3], 0), 6);
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] + input1[index], 0), 6);
@ -228,17 +228,17 @@ int ElementAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output,
}
int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
for (int index = 0; index < block_c8; index += C8NUM) {
output[0] = input0[0] - input1[0];
output[1] = input0[1] - input1[1];
output[2] = input0[2] - input1[2];
output[3] = input0[3] - input1[3];
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = input0[index] - input1[index];
@ -247,10 +247,10 @@ int ElementSubFp16(float16_t *input0, float16_t *input1, float16_t *output, int
}
int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
for (int index = 0; index < block_c8; index += C8NUM) {
float16_t res = input0[0] - input1[0];
output[0] = res > 0 ? res : 0;
res = input0[1] - input1[1];
@ -259,9 +259,9 @@ int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
output[2] = res > 0 ? res : 0;
res = input0[3] - input1[3];
output[3] = res > 0 ? res : 0;
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
float16_t res = input0[index] - input1[index];
@ -271,17 +271,17 @@ int ElementSubReluFp16(float16_t *input0, float16_t *input1, float16_t *output,
}
int ElementSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
int block_mod = element_size % C4NUM;
int block_c4 = element_size - block_mod;
int block_mod = element_size % C8NUM;
int block_c8 = element_size - block_mod;
for (int index = 0; index < block_c4; index += C4NUM) {
for (int index = 0; index < block_c8; index += C8NUM) {
output[0] = MSMIN(MSMAX(input0[0] - input1[0], 0), 6);
output[1] = MSMIN(MSMAX(input0[1] - input1[1], 0), 6);
output[2] = MSMIN(MSMAX(input0[2] - input1[2], 0), 6);
output[3] = MSMIN(MSMAX(input0[3] - input1[3], 0), 6);
input0 += C4NUM;
input1 += C4NUM;
output += C4NUM;
input0 += C8NUM;
input1 += C8NUM;
output += C8NUM;
}
for (int index = 0; index < block_mod; ++index) {
output[index] = MSMIN(MSMAX(input0[index] - input1[index], 0), 6);

View File

@ -16,6 +16,7 @@
#include "nnacl/fp32/gather.h"
#include <string.h>
#include "nnacl/errorcode.h"
inline int Stride(int *shape, int rank, int index) {
int i, stride = 1;
@ -33,10 +34,26 @@ int Gather(float *input, int outer_size, int inner_size, int limit, int *indices
float *outputm = output + inner_size * m * indices_element_size;
for (i = 0; i < indices_element_size; ++i) {
if (indices[i] < 0 || indices[i] > limit) {
return -1;
return NNACL_ERR;
}
memcpy(outputm + i * inner_size, inputm + indices[i] * inner_size, sizeof(float) * inner_size);
}
}
return 0;
return NNACL_OK;
}
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices,
int indices_element_size, int32_t *output) {
int i, m;
for (m = 0; m < outer_size; ++m) {
const int32_t *inputm = input + inner_size * m * limit;
int32_t *outputm = output + inner_size * m * indices_element_size;
for (i = 0; i < indices_element_size; ++i) {
if (indices[i] < 0 || indices[i] > limit) {
return NNACL_ERR;
}
memcpy(outputm + i * inner_size, inputm + indices[i] * inner_size, sizeof(int32_t) * inner_size);
}
}
return NNACL_OK;
}

View File

@ -30,6 +30,8 @@ extern "C" {
#endif
int Gather(float *input, int outer_size, int inner_size, int limit, int *indices, int indices_element_size,
float *output);
int GatherInt32(const int32_t *input, int outer_size, int inner_size, int limit, int *indices,
int indices_element_size, int32_t *output);
#ifdef __cplusplus
}
#endif

View File

@ -16,11 +16,20 @@
#include "nnacl/squeeze.h"
#include <string.h>
#include "nnacl/errorcode.h"
int DoSqueeze(float *in_data, float *out_data, size_t data_size) {
if (in_data == NULL || out_data == NULL) {
return -1;
}
(void)memcpy(out_data, in_data, data_size);
return 0;
return NNACL_OK;
}
int DoSqueezeInt32(int32_t *in_data, int32_t *out_data, size_t data_size) {
if (in_data == NULL || out_data == NULL) {
return -1;
}
(void)memcpy(out_data, in_data, data_size);
return NNACL_OK;
}

View File

@ -28,6 +28,7 @@ typedef struct SqueezeParameter {
extern "C" {
#endif
int DoSqueeze(float *input_ptr, float *output_ptr, size_t data_size);
int DoSqueezeInt32(int32_t *in_data, int32_t *out_data, size_t data_size);
#ifdef __cplusplus
}
#endif