!11709 [MSLITE] fp16 argmin max

From: @ling_qiao_min
Reviewed-by: 
Signed-off-by:
This commit is contained in:
mindspore-ci-bot 2021-02-02 17:23:53 +08:00 committed by Gitee
commit 9557bef491
8 changed files with 324 additions and 31 deletions

View File

@ -17,14 +17,22 @@
#ifndef MINDSPORE_LITE_NNACL_ARG_MIN_MAX_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_ARG_MIN_MAX_PARAMETER_H_
#ifdef ENABLE_ARM64
#include <arm_neon.h>
#endif
#include "nnacl/op_base.h"
typedef int (*COMPARE_FUNCTION)(const void *a, const void *b);
// One candidate collected along the axis being reduced: where it sat on that axis and what value it held.
// Arrays of ArgElement are what the qsort-based arg-min/arg-max routines sort.
typedef struct ArgElement {
// Position of the element along the reduced axis.
uint32_t index_;
// Value of the element. Only one member is meaningful at a time; writers and comparators must agree on
// which member they use.
union ArgData {
int8_t i8_data_;
int32_t i_data_;
float f_data_;
#ifdef ENABLE_ARM64
// Half-precision value. Guarded like the <arm_neon.h> include at the top of this header.
float16_t f16_data_;
#endif
} data_;
} ArgElement;

View File

@ -0,0 +1,240 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/fp16/arg_min_max_fp16.h"
#include <math.h>
int ArgCompareAscFp16(const void *a, const void *b) {
float16_t a_value = ((ArgElement *)a)->data_.f16_data_;
float16_t b_value = ((ArgElement *)b)->data_.f16_data_;
if (b_value > a_value) {
return -1;
}
if (b_value < a_value) {
return 1;
}
return 0;
}
int ArgCompareDescFp16(const void *a, const void *b) {
float16_t b_value = ((ArgElement *)b)->data_.f16_data_;
float16_t a_value = ((ArgElement *)a)->data_.f16_data_;
if (b_value > a_value) {
return 1;
}
if (b_value < a_value) {
return -1;
}
return 0;
}
// Arg-max along a single axis for the topk == 1 case, done with one linear scan per output element.
//
// The input is addressed as [pre_axis_count, axis_count, after_axis_count]. For every (i, j) pair the
// axis_count values along the middle dimension are scanned once.
//   output       : receives the maximum value when param->out_value_ is set, otherwise the index of the
//                  maximum converted to float16_t.
//   output_value : when non-NULL, always receives the maximum value.
// Ties keep the lowest index. A NaN input never compares greater, so it is never selected.
void ArgMaxTopK1Fp16(const float16_t *input, float16_t *output, float16_t *output_value,
                     const ArgMinMaxParameter *param, int pre_axis_count, int axis_count, int after_axis_count) {
  const bool out_value = param->out_value_;
  for (int i = 0; i < pre_axis_count; ++i) {
    // Widen before multiplying so the offsets are not computed in int arithmetic.
    const size_t output_offset = (size_t)i * (size_t)after_axis_count;
    const size_t input_offset = output_offset * (size_t)axis_count;
    for (int j = 0; j < after_axis_count; ++j) {
      // -INFINITY is exactly representable in half precision. The previous seed, -FLT_MAX, lies far outside
      // the fp16 range and depended on an out-of-range float-to-half conversion.
      float16_t value = (float16_t)(-INFINITY);
      // Track the position as an integer and convert it only once, when it is stored.
      int index = 0;
      for (int k = 0; k < axis_count; ++k) {
        const float16_t value_tmp = input[input_offset + (size_t)k * (size_t)after_axis_count + (size_t)j];
        if (value_tmp > value) {
          value = value_tmp;
          index = k;
        }
      }
      output[output_offset + j] = out_value ? value : (float16_t)index;
      if (output_value != NULL) {
        output_value[output_offset + j] = value;
      }
    }
  }
}
// Arg-min along a single axis for the topk == 1 case, done with one linear scan per output element.
//
// The input is addressed as [pre_axis_count, axis_count, after_axis_count]. For every (i, j) pair the
// axis_count values along the middle dimension are scanned once.
//   output       : receives the minimum value when param->out_value_ is set, otherwise the index of the
//                  minimum converted to float16_t.
//   output_value : when non-NULL, always receives the minimum value.
// Ties keep the lowest index. A NaN input never compares smaller, so it is never selected.
void ArgMinTopK1Fp16(const float16_t *input, float16_t *output, float16_t *output_value,
                     const ArgMinMaxParameter *param, int pre_axis_count, int axis_count, int after_axis_count) {
  const bool out_value = param->out_value_;
  for (int i = 0; i < pre_axis_count; ++i) {
    // Widen before multiplying so the offsets are not computed in int arithmetic.
    const size_t output_offset = (size_t)i * (size_t)after_axis_count;
    const size_t input_offset = output_offset * (size_t)axis_count;
    for (int j = 0; j < after_axis_count; ++j) {
      // +INFINITY is exactly representable in half precision. The previous seed, FLT_MAX, lies far outside
      // the fp16 range and depended on an out-of-range float-to-half conversion.
      float16_t value = (float16_t)INFINITY;
      // Track the position as an integer and convert it only once, when it is stored.
      int index = 0;
      for (int k = 0; k < axis_count; ++k) {
        const float16_t value_tmp = input[input_offset + (size_t)k * (size_t)after_axis_count + (size_t)j];
        if (value_tmp < value) {
          value = value_tmp;
          index = k;
        }
      }
      output[output_offset + j] = out_value ? value : (float16_t)index;
      if (output_value != NULL) {
        output_value[output_offset + j] = value;
      }
    }
  }
}
// Top-k arg-min/arg-max along axis 0 for half-precision data.
//
// For each of the param->in_strides_[0] positions inside one axis-0 slice, the in_shape[0] values along the
// axis are gathered into param->arg_elements_, sorted with compare_func, and the first param->topk_ entries
// are written out:
//   output       : the sorted values when param->out_value_ is set, otherwise their axis indices.
//   output_value : when non-NULL, the sorted values.
// param->arg_elements_ is used as scratch space and is read up to index max(in_shape[0], param->topk_) - 1.
void ArgMinMaxDim0Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
  ArgElement *elements = param->arg_elements_;
  for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
    for (int j = 0; j < in_shape[0]; ++j) {
      size_t offset = param->in_strides_[0] * j + i;
      elements[j].index_ = j;
      // Store through f16_data_: the fp16 comparators read that union member. Writing f_data_ instead made
      // qsort compare the low 16 bits of a 32-bit float rather than the actual value.
      elements[j].data_.f16_data_ = input[offset];
    }
    qsort(elements, in_shape[0], sizeof(ArgElement), compare_func);
    for (int j = 0; j < param->topk_; ++j) {
      size_t out_offset = j * param->out_strides_[0] + i;
      const ArgElement *picked = &elements[j];
      output[out_offset] = param->out_value_ ? picked->data_.f16_data_ : picked->index_;
      if (output_value != NULL) {
        output_value[out_offset] = picked->data_.f16_data_;
      }
    }
  }
}
// Top-k arg-min/arg-max along axis 1 for half-precision data.
//
// For every axis-0 index and every position inside one axis-1 slice (param->in_strides_[1] positions), the
// in_shape[1] values along the axis are gathered into param->arg_elements_, sorted with compare_func, and the
// first param->topk_ entries are written out:
//   output       : the sorted values when param->out_value_ is set, otherwise their axis indices.
//   output_value : when non-NULL, the sorted values.
// param->arg_elements_ is used as scratch space and is read up to index max(in_shape[1], param->topk_) - 1.
void ArgMinMaxDim1Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
  ArgElement *elements = param->arg_elements_;
  int in_shape1 = in_shape[1];
  for (int i = 0; i < in_shape[0]; ++i) {
    size_t in_dim0_offset = i * param->in_strides_[0];
    size_t out_dim0_offset = i * param->out_strides_[0];
    for (int j = 0; j < param->in_strides_[1]; ++j) {
      for (int k = 0; k < in_shape1; ++k) {
        size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
        elements[k].index_ = k;
        // Store through f16_data_: the fp16 comparators read that union member. Writing f_data_ instead made
        // qsort compare the low 16 bits of a 32-bit float rather than the actual value.
        elements[k].data_.f16_data_ = input[offset];
      }
      qsort(elements, in_shape1, sizeof(ArgElement), compare_func);
      for (int k = 0; k < param->topk_; ++k) {
        size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
        const ArgElement *picked = &elements[k];
        output[out_offset] = param->out_value_ ? picked->data_.f16_data_ : picked->index_;
        if (output_value != NULL) {
          output_value[out_offset] = picked->data_.f16_data_;
        }
      }
    }
  }
}
// Top-k arg-min/arg-max along axis 2 for half-precision data.
//
// For every (axis-0, axis-1) index pair and every position inside one axis-2 slice (param->in_strides_[2]
// positions), the in_shape[2] values along the axis are gathered into param->arg_elements_, sorted with
// compare_func, and the first param->topk_ entries are written out:
//   output       : the sorted values when param->out_value_ is set, otherwise their axis indices.
//   output_value : when non-NULL, the sorted values.
// param->arg_elements_ is used as scratch space and is read up to index max(in_shape[2], param->topk_) - 1.
void ArgMinMaxDim2Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
  ArgElement *elements = param->arg_elements_;
  int in_shape1 = in_shape[1];
  int in_shape2 = in_shape[2];
  for (int i = 0; i < in_shape[0]; ++i) {
    size_t in_dim0_offset = i * param->in_strides_[0];
    size_t out_dim0_offset = i * param->out_strides_[0];
    for (int j = 0; j < in_shape1; ++j) {
      size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
      size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
      for (int k = 0; k < param->in_strides_[2]; ++k) {
        for (int l = 0; l < in_shape2; ++l) {
          size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
          elements[l].index_ = l;
          // Store through f16_data_: the fp16 comparators read that union member. Writing f_data_ instead
          // made qsort compare the low 16 bits of a 32-bit float rather than the actual value.
          elements[l].data_.f16_data_ = input[offset];
        }
        qsort(elements, in_shape2, sizeof(ArgElement), compare_func);
        for (int l = 0; l < param->topk_; ++l) {
          size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
          const ArgElement *picked = &elements[l];
          output[out_offset] = param->out_value_ ? picked->data_.f16_data_ : picked->index_;
          if (output_value != NULL) {
            output_value[out_offset] = picked->data_.f16_data_;
          }
        }
      }
    }
  }
}
// Top-k arg-min/arg-max along axis 3 (addressed here with unit stride) for half-precision data.
//
// For every (axis-0, axis-1, axis-2) index triple, the in_shape[3] consecutive values along the axis are
// gathered into param->arg_elements_, sorted with compare_func, and the first param->topk_ entries are
// written out consecutively:
//   output       : the sorted values when param->out_value_ is set, otherwise their axis indices.
//   output_value : when non-NULL, the sorted values.
// param->arg_elements_ is used as scratch space and is read up to index max(in_shape[3], param->topk_) - 1.
void ArgMinMaxDim3Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
                       const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
  ArgElement *elements = param->arg_elements_;
  int in_shape1 = in_shape[1];
  int in_shape2 = in_shape[2];
  int in_shape3 = in_shape[3];
  for (int i = 0; i < in_shape[0]; ++i) {
    size_t in_dim0_offset = i * param->in_strides_[0];
    size_t out_dim0_offset = i * param->out_strides_[0];
    for (int j = 0; j < in_shape1; ++j) {
      size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
      size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
      for (int k = 0; k < in_shape2; ++k) {
        size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
        size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
        for (int l = 0; l < in_shape3; ++l) {
          size_t offset = l + in_dim2_offset;
          elements[l].index_ = l;
          // Store through f16_data_: the fp16 comparators read that union member. Writing f_data_ instead
          // made qsort compare the low 16 bits of a 32-bit float rather than the actual value.
          elements[l].data_.f16_data_ = input[offset];
        }
        qsort(elements, in_shape3, sizeof(ArgElement), compare_func);
        for (int l = 0; l < param->topk_; ++l) {
          size_t out_offset = out_dim2_offset + l;
          const ArgElement *picked = &elements[l];
          output[out_offset] = param->out_value_ ? picked->data_.f16_data_ : picked->index_;
          if (output_value != NULL) {
            output_value[out_offset] = picked->data_.f16_data_;
          }
        }
      }
    }
  }
}
// Entry point for half-precision arg-min/arg-max.
//
// When param->topk_ is 1 the work goes to the single-pass scanners, with the shape collapsed around
// param->axis_ by ComputeAxisDims. Otherwise a comparator is chosen from param->get_max_ (descending for max,
// ascending for min) and the sort-based routine for the requested axis runs. Axis values other than 0..3
// are ignored in the sort-based path.
void ArgMinMaxFp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
                   const ArgMinMaxParameter *param) {
  if (param->topk_ == 1) {
    int pre_axis_count = 1;
    int axis_count = 1;
    int after_axis_count = 1;
    ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
    if (!param->get_max_) {
      ArgMinTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
    } else {
      ArgMaxTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
    }
    return;
  }

  COMPARE_FUNCTION compare_function = param->get_max_ ? ArgCompareDescFp16 : ArgCompareAscFp16;

  if (param->axis_ == 0) {
    ArgMinMaxDim0Fp16(input, output, output_value, in_shape, param, compare_function);
  } else if (param->axis_ == 1) {
    ArgMinMaxDim1Fp16(input, output, output_value, in_shape, param, compare_function);
  } else if (param->axis_ == 2) {
    ArgMinMaxDim2Fp16(input, output, output_value, in_shape, param, compare_function);
  } else if (param->axis_ == 3) {
    ArgMinMaxDim3Fp16(input, output, output_value, in_shape, param, compare_function);
  }
}

View File

@ -0,0 +1,32 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP16_ARG_MIN_MAX_H_
#define MINDSPORE_LITE_NNACL_FP16_ARG_MIN_MAX_H_
#include <float.h>
#include "nnacl/arg_min_max_parameter.h"
#include "nnacl/nnacl_common.h"
#ifdef __cplusplus
extern "C" {
#endif
// Half-precision arg-min/arg-max over `input`, whose dimensions are given by `in_shape`.
//   output       : receives values when param->out_value_ is set, otherwise indices stored as float16_t.
//   output_value : optional (may be NULL); when provided it receives the selected values.
//   param        : supplies axis_, topk_, get_max_, out_value_, dims_size_ and the in/out strides. When
//                  topk_ is not 1, param->arg_elements_ is used as sorting scratch space and axis_ must be
//                  in 0..3 for anything to be written.
void ArgMinMaxFp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
const ArgMinMaxParameter *param);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP16_ARG_MIN_MAX_H_

View File

@ -91,21 +91,6 @@ void ArgMinTopK1(const float *input, float *output, float *output_value, const A
}
}
// Collapses `shape` (dims_number entries) around `axis` into three extents:
//   *pre_axis_count   : product of the dimensions before `axis`
//   *axis_count       : shape[axis]
//   *after_axis_count : product of the dimensions after `axis`
void GetCalcParameter(const int *shape, int dims_number, int axis, int *pre_axis_count, int *axis_count,
                      int *after_axis_count) {
  int dim = 0;

  *pre_axis_count = 1;
  while (dim < axis) {
    *pre_axis_count *= shape[dim];
    ++dim;
  }

  *axis_count = shape[axis];

  *after_axis_count = 1;
  dim = axis + 1;
  while (dim < dims_number) {
    *after_axis_count *= shape[dim];
    ++dim;
  }
}
void ArgMinMaxDim0(const float *input, float *output, float *output_value, const int *in_shape,
const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
@ -221,7 +206,7 @@ void ArgMinMaxFp32(const float *input, float *output, float *output_value, const
int pre_axis_count = 1;
int axis_count = 1;
int after_axis_count = 1;
GetCalcParameter(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
if (param->get_max_) {
ArgMaxTopK1(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);

View File

@ -16,10 +16,9 @@
#ifndef MINDSPORE_LITE_NNACL_FP32_ARG_MIN_MAX_H_
#define MINDSPORE_LITE_NNACL_FP32_ARG_MIN_MAX_H_
#include "nnacl/nnacl_common.h"
#include "nnacl/arg_min_max_parameter.h"
typedef int (*COMPARE_FUNCTION)(const void *a, const void *b);
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -31,6 +31,17 @@ inline void ComputeStrides(const int *shape, int *strides, const int ndim) {
}
}
// Collapses `shape` (shape_size entries) around `axis` into three extents:
//   *out_count  : product of the dimensions before `axis`
//   *axis_count : shape[axis]; left untouched when `axis` is not a valid index into `shape`
//   *in_count   : product of the dimensions after `axis`
static inline void ComputeAxisDims(const int *shape, int shape_size, int axis, int *out_count, int *axis_count,
                                   int *in_count) {
  int dim = 0;
  *out_count = 1;
  *in_count = 1;

  // Dimensions in front of the axis.
  for (; dim < shape_size && dim < axis; ++dim) {
    *out_count *= shape[dim];
  }

  // The axis itself, if it lies inside the shape.
  if (dim < shape_size && dim == axis) {
    *axis_count = shape[dim];
    ++dim;
  }

  // Everything that remains is behind the axis.
  for (; dim < shape_size; ++dim) {
    *in_count *= shape[dim];
  }
}
static const unsigned int FP32_BIT_SIZE = 32;
static const unsigned int FP32_EXPONENT_BIAS = 127;
static const unsigned int FP32_SIGNIFICAND = 23;

View File

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/argminmax_fp32.h"
#include "src/runtime/kernel/arm/base/argminmax_base.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
@ -52,25 +52,38 @@ int ArgMinMaxCPUKernel::ReSize() {
}
int ArgMinMaxCPUKernel::Run() {
float *input_data = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
float *output_data = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
float *output_value = nullptr;
if (out_tensors_.size() == 2) {
output_value = reinterpret_cast<float *>(out_tensors_.at(1)->data_c());
}
auto input = in_tensors_.at(0);
auto shape = input->shape();
auto shape = in_tensors_.at(0)->shape();
auto input_data = input->data_c();
auto output_data = out_tensors_.at(0)->data_c();
void *output_value = nullptr;
if (out_tensors_.size() == 2) {
output_value = out_tensors_.at(1)->data_c();
}
MS_ASSERT(context_->allocator != nullptr);
if (arg_param_->topk_ > 1 || arg_param_->keep_dims_) {
arg_param_->arg_elements_ =
reinterpret_cast<ArgElement *>(context_->allocator->Malloc(sizeof(ArgElement) * shape[arg_param_->axis_]));
if (arg_param_->arg_elements_ == nullptr) {
MS_LOG(ERROR) << "malloc memroy fail!";
MS_LOG(ERROR) << "malloc memory fail!";
return RET_ERROR;
}
}
ArgMinMaxFp32(input_data, output_data, output_value, reinterpret_cast<const int *>(shape.data()), arg_param_);
if (input->data_type() == kNumberTypeFloat32) {
ArgMinMaxFp32(reinterpret_cast<float *>(input_data), reinterpret_cast<float *>(output_data),
reinterpret_cast<float *>(output_value), shape.data(), arg_param_);
#ifdef ENABLE_ARM64
} else if (input->data_type() == kNumberTypeFloat16) {
ArgMinMaxFp16(reinterpret_cast<float16_t *>(input_data), reinterpret_cast<float16_t *>(output_data),
reinterpret_cast<float16_t *>(output_value), shape.data(), arg_param_);
#endif
} else {
MS_LOG(ERROR) << "unsupported data type!";
}
context_->allocator->Free(arg_param_->arg_elements_);
arg_param_->arg_elements_ = nullptr;
return RET_OK;
@ -78,4 +91,6 @@ int ArgMinMaxCPUKernel::Run() {
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMax, LiteKernelCreator<ArgMinMaxCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMin, LiteKernelCreator<ArgMinMaxCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ArgMax, LiteKernelCreator<ArgMinMaxCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ArgMin, LiteKernelCreator<ArgMinMaxCPUKernel>)
} // namespace mindspore::kernel

View File

@ -13,12 +13,15 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARGMINMAX_BASE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARGMINMAX_BASE_H_
#include <vector>
#include "include/errorcode.h"
#include "nnacl/fp32/arg_min_max_fp32.h"
#ifdef ENABLE_ARM64
#include "nnacl/fp16/arg_min_max_fp16.h"
#endif
#include "nnacl/common_func.h"
#include "src/lite_kernel.h"
@ -43,4 +46,4 @@ class ArgMinMaxCPUKernel : public LiteKernel {
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARGMINMAX_BASE_H_