forked from mindspore-Ecosystem/mindspore
!11709 [MSLITE] fp16 argmin max
From: @ling_qiao_min Reviewed-by: Signed-off-by:
This commit is contained in:
commit
9557bef491
|
@ -17,14 +17,22 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_ARG_MIN_MAX_PARAMETER_H_
|
||||
#define MINDSPORE_LITE_NNACL_ARG_MIN_MAX_PARAMETER_H_
|
||||
|
||||
#ifdef ENABLE_ARM64
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#include "nnacl/op_base.h"
|
||||
|
||||
typedef int (*COMPARE_FUNCTION)(const void *a, const void *b);
|
||||
|
||||
typedef struct ArgElement {
|
||||
uint32_t index_;
|
||||
union ArgData {
|
||||
int8_t i8_data_;
|
||||
int32_t i_data_;
|
||||
float f_data_;
|
||||
#ifdef ENABLE_ARM64
|
||||
float16_t f16_data_;
|
||||
#endif
|
||||
} data_;
|
||||
} ArgElement;
|
||||
|
||||
|
|
|
@ -0,0 +1,240 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "nnacl/fp16/arg_min_max_fp16.h"
|
||||
|
||||
int ArgCompareAscFp16(const void *a, const void *b) {
|
||||
float16_t a_value = ((ArgElement *)a)->data_.f16_data_;
|
||||
float16_t b_value = ((ArgElement *)b)->data_.f16_data_;
|
||||
if (b_value > a_value) {
|
||||
return -1;
|
||||
}
|
||||
if (b_value < a_value) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ArgCompareDescFp16(const void *a, const void *b) {
|
||||
float16_t b_value = ((ArgElement *)b)->data_.f16_data_;
|
||||
float16_t a_value = ((ArgElement *)a)->data_.f16_data_;
|
||||
if (b_value > a_value) {
|
||||
return 1;
|
||||
}
|
||||
if (b_value < a_value) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void ArgMaxTopK1Fp16(const float16_t *input, float16_t *output, float16_t *output_value,
|
||||
const ArgMinMaxParameter *param, int pre_axis_count, int axis_count, int after_axis_count) {
|
||||
bool out_value = param->out_value_;
|
||||
for (int i = 0; i < pre_axis_count; ++i) {
|
||||
size_t output_offset = i * after_axis_count;
|
||||
size_t input_offset = output_offset * axis_count;
|
||||
for (int j = 0; j < after_axis_count; ++j) {
|
||||
float16_t value = -FLT_MAX;
|
||||
float16_t index = 0.0f;
|
||||
for (int k = 0; k < axis_count; ++k) {
|
||||
float16_t value_tmp = input[input_offset + k * after_axis_count + j];
|
||||
if (value_tmp > value) {
|
||||
value = value_tmp;
|
||||
index = k;
|
||||
}
|
||||
}
|
||||
output[output_offset + j] = out_value ? value : index;
|
||||
if (output_value != NULL) {
|
||||
output_value[output_offset + j] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ArgMinTopK1Fp16(const float16_t *input, float16_t *output, float16_t *output_value,
|
||||
const ArgMinMaxParameter *param, int pre_axis_count, int axis_count, int after_axis_count) {
|
||||
bool out_value = param->out_value_;
|
||||
for (int i = 0; i < pre_axis_count; ++i) {
|
||||
size_t output_offset = i * after_axis_count;
|
||||
size_t input_offset = output_offset * axis_count;
|
||||
for (int j = 0; j < after_axis_count; ++j) {
|
||||
float16_t value = FLT_MAX;
|
||||
float16_t index = 0.0f;
|
||||
for (int k = 0; k < axis_count; ++k) {
|
||||
float16_t value_tmp = input[input_offset + k * after_axis_count + j];
|
||||
if (value_tmp < value) {
|
||||
value = value_tmp;
|
||||
index = k;
|
||||
}
|
||||
}
|
||||
output[output_offset + j] = out_value ? value : index;
|
||||
if (output_value != NULL) {
|
||||
output_value[output_offset + j] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ArgMinMaxDim0Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
|
||||
const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
|
||||
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
|
||||
for (int j = 0; j < in_shape[0]; ++j) {
|
||||
size_t offset = param->in_strides_[0] * j + i;
|
||||
param->arg_elements_[j].index_ = j;
|
||||
param->arg_elements_[j].data_.f_data_ = input[offset];
|
||||
}
|
||||
qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), *compare_func);
|
||||
for (int j = 0; j < param->topk_; ++j) {
|
||||
size_t out_offset = j * param->out_strides_[0] + i;
|
||||
output[out_offset] = param->out_value_ ? param->arg_elements_[j].data_.f_data_ : param->arg_elements_[j].index_;
|
||||
if (output_value != NULL) {
|
||||
output_value[out_offset] = param->arg_elements_[j].data_.f_data_;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void ArgMinMaxDim1Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
|
||||
const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
|
||||
int in_shape1 = in_shape[1];
|
||||
for (int i = 0; i < in_shape[0]; ++i) {
|
||||
size_t in_dim0_offset = i * param->in_strides_[0];
|
||||
size_t out_dim0_offset = i * param->out_strides_[0];
|
||||
for (int j = 0; j < param->in_strides_[1]; ++j) {
|
||||
for (int k = 0; k < in_shape1; ++k) {
|
||||
size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
|
||||
param->arg_elements_[k].index_ = k;
|
||||
param->arg_elements_[k].data_.f_data_ = input[offset];
|
||||
}
|
||||
qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), *compare_func);
|
||||
for (int k = 0; k < param->topk_; ++k) {
|
||||
size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
|
||||
output[out_offset] = param->out_value_ ? param->arg_elements_[k].data_.f_data_ : param->arg_elements_[k].index_;
|
||||
if (output_value != NULL) {
|
||||
output_value[out_offset] = param->arg_elements_[k].data_.f_data_;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void ArgMinMaxDim2Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
|
||||
const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
|
||||
int in_shape1 = in_shape[1];
|
||||
int in_shape2 = in_shape[2];
|
||||
for (int i = 0; i < in_shape[0]; ++i) {
|
||||
size_t in_dim0_offset = i * param->in_strides_[0];
|
||||
size_t out_dim0_offset = i * param->out_strides_[0];
|
||||
for (int j = 0; j < in_shape1; ++j) {
|
||||
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
|
||||
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
|
||||
for (int k = 0; k < param->in_strides_[2]; ++k) {
|
||||
for (int l = 0; l < in_shape2; ++l) {
|
||||
size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
|
||||
param->arg_elements_[l].index_ = l;
|
||||
param->arg_elements_[l].data_.f_data_ = input[offset];
|
||||
}
|
||||
qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), *compare_func);
|
||||
for (int l = 0; l < param->topk_; ++l) {
|
||||
size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
|
||||
|
||||
output[out_offset] =
|
||||
param->out_value_ ? param->arg_elements_[l].data_.f_data_ : param->arg_elements_[l].index_;
|
||||
if (output_value != NULL) {
|
||||
output_value[out_offset] = param->arg_elements_[l].data_.f_data_;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ArgMinMaxDim3Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
|
||||
const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
|
||||
int in_shape1 = in_shape[1];
|
||||
int in_shape2 = in_shape[2];
|
||||
int in_shape3 = in_shape[3];
|
||||
for (int i = 0; i < in_shape[0]; ++i) {
|
||||
size_t in_dim0_offset = i * param->in_strides_[0];
|
||||
size_t out_dim0_offset = i * param->out_strides_[0];
|
||||
for (int j = 0; j < in_shape1; ++j) {
|
||||
size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
|
||||
size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
|
||||
for (int k = 0; k < in_shape2; ++k) {
|
||||
size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
|
||||
size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
|
||||
for (int l = 0; l < in_shape3; ++l) {
|
||||
size_t offset = l + in_dim2_offset;
|
||||
param->arg_elements_[l].index_ = l;
|
||||
param->arg_elements_[l].data_.f_data_ = input[offset];
|
||||
}
|
||||
qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), *compare_func);
|
||||
for (int l = 0; l < param->topk_; ++l) {
|
||||
size_t out_offset = out_dim2_offset + l;
|
||||
output[out_offset] =
|
||||
param->out_value_ ? param->arg_elements_[l].data_.f_data_ : param->arg_elements_[l].index_;
|
||||
if (output_value != NULL) {
|
||||
output_value[out_offset] = param->arg_elements_[l].data_.f_data_;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ArgMinMaxFp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
|
||||
const ArgMinMaxParameter *param) {
|
||||
if (param->topk_ == 1) {
|
||||
int pre_axis_count = 1;
|
||||
int axis_count = 1;
|
||||
int after_axis_count = 1;
|
||||
ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
|
||||
|
||||
if (param->get_max_) {
|
||||
ArgMaxTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
|
||||
} else {
|
||||
ArgMinTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
COMPARE_FUNCTION compare_function = NULL;
|
||||
if (param->get_max_) {
|
||||
compare_function = ArgCompareDescFp16;
|
||||
} else {
|
||||
compare_function = ArgCompareAscFp16;
|
||||
}
|
||||
|
||||
switch (param->axis_) {
|
||||
case 0:
|
||||
ArgMinMaxDim0Fp16(input, output, output_value, in_shape, param, compare_function);
|
||||
break;
|
||||
case 1:
|
||||
ArgMinMaxDim1Fp16(input, output, output_value, in_shape, param, compare_function);
|
||||
break;
|
||||
case 2:
|
||||
ArgMinMaxDim2Fp16(input, output, output_value, in_shape, param, compare_function);
|
||||
break;
|
||||
case 3:
|
||||
ArgMinMaxDim3Fp16(input, output, output_value, in_shape, param, compare_function);
|
||||
break;
|
||||
}
|
||||
return;
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_NNACL_FP16_ARG_MIN_MAX_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP16_ARG_MIN_MAX_H_
|
||||
|
||||
#include <float.h>
|
||||
#include "nnacl/arg_min_max_parameter.h"
|
||||
#include "nnacl/nnacl_common.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
void ArgMinMaxFp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
|
||||
const ArgMinMaxParameter *param);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MINDSPORE_LITE_NNACL_FP16_ARG_MIN_MAX_H_
|
|
@ -91,21 +91,6 @@ void ArgMinTopK1(const float *input, float *output, float *output_value, const A
|
|||
}
|
||||
}
|
||||
|
||||
void GetCalcParameter(const int *shape, int dims_number, int axis, int *pre_axis_count, int *axis_count,
|
||||
int *after_axis_count) {
|
||||
*pre_axis_count = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
*pre_axis_count = (*pre_axis_count) * shape[i];
|
||||
}
|
||||
|
||||
*axis_count = shape[axis];
|
||||
|
||||
*after_axis_count = 1;
|
||||
for (int i = axis + 1; i < dims_number; ++i) {
|
||||
*after_axis_count = (*after_axis_count) * shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
void ArgMinMaxDim0(const float *input, float *output, float *output_value, const int *in_shape,
|
||||
const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
|
||||
for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
|
||||
|
@ -221,7 +206,7 @@ void ArgMinMaxFp32(const float *input, float *output, float *output_value, const
|
|||
int pre_axis_count = 1;
|
||||
int axis_count = 1;
|
||||
int after_axis_count = 1;
|
||||
GetCalcParameter(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
|
||||
ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
|
||||
|
||||
if (param->get_max_) {
|
||||
ArgMaxTopK1(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
|
||||
|
|
|
@ -16,10 +16,9 @@
|
|||
#ifndef MINDSPORE_LITE_NNACL_FP32_ARG_MIN_MAX_H_
|
||||
#define MINDSPORE_LITE_NNACL_FP32_ARG_MIN_MAX_H_
|
||||
|
||||
#include "nnacl/nnacl_common.h"
|
||||
#include "nnacl/arg_min_max_parameter.h"
|
||||
|
||||
typedef int (*COMPARE_FUNCTION)(const void *a, const void *b);
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
|
|
@ -31,6 +31,17 @@ inline void ComputeStrides(const int *shape, int *strides, const int ndim) {
|
|||
}
|
||||
}
|
||||
|
||||
static inline void ComputeAxisDims(const int *shape, int shape_size, int axis, int *out_count, int *axis_count,
|
||||
int *in_count) {
|
||||
*out_count = 1;
|
||||
*in_count = 1;
|
||||
for (int i = 0; i < shape_size; i++) {
|
||||
if (i < axis) *out_count = (*out_count) * shape[i];
|
||||
if (i == axis) *axis_count = shape[axis];
|
||||
if (i > axis) *in_count = (*in_count) * shape[i];
|
||||
}
|
||||
}
|
||||
|
||||
static const unsigned int FP32_BIT_SIZE = 32;
|
||||
static const unsigned int FP32_EXPONENT_BIAS = 127;
|
||||
static const unsigned int FP32_SIGNIFICAND = 23;
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp32/argminmax_fp32.h"
|
||||
#include "src/runtime/kernel/arm/base/argminmax_base.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
||||
|
@ -52,25 +52,38 @@ int ArgMinMaxCPUKernel::ReSize() {
|
|||
}
|
||||
|
||||
int ArgMinMaxCPUKernel::Run() {
|
||||
float *input_data = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
|
||||
float *output_data = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
|
||||
float *output_value = nullptr;
|
||||
if (out_tensors_.size() == 2) {
|
||||
output_value = reinterpret_cast<float *>(out_tensors_.at(1)->data_c());
|
||||
}
|
||||
auto input = in_tensors_.at(0);
|
||||
auto shape = input->shape();
|
||||
|
||||
auto shape = in_tensors_.at(0)->shape();
|
||||
auto input_data = input->data_c();
|
||||
auto output_data = out_tensors_.at(0)->data_c();
|
||||
void *output_value = nullptr;
|
||||
if (out_tensors_.size() == 2) {
|
||||
output_value = out_tensors_.at(1)->data_c();
|
||||
}
|
||||
|
||||
MS_ASSERT(context_->allocator != nullptr);
|
||||
if (arg_param_->topk_ > 1 || arg_param_->keep_dims_) {
|
||||
arg_param_->arg_elements_ =
|
||||
reinterpret_cast<ArgElement *>(context_->allocator->Malloc(sizeof(ArgElement) * shape[arg_param_->axis_]));
|
||||
if (arg_param_->arg_elements_ == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc memroy fail!";
|
||||
MS_LOG(ERROR) << "malloc memory fail!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
ArgMinMaxFp32(input_data, output_data, output_value, reinterpret_cast<const int *>(shape.data()), arg_param_);
|
||||
if (input->data_type() == kNumberTypeFloat32) {
|
||||
ArgMinMaxFp32(reinterpret_cast<float *>(input_data), reinterpret_cast<float *>(output_data),
|
||||
reinterpret_cast<float *>(output_value), shape.data(), arg_param_);
|
||||
#ifdef ENABLE_ARM64
|
||||
} else if (input->data_type() == kNumberTypeFloat16) {
|
||||
ArgMinMaxFp16(reinterpret_cast<float16_t *>(input_data), reinterpret_cast<float16_t *>(output_data),
|
||||
reinterpret_cast<float16_t *>(output_value), shape.data(), arg_param_);
|
||||
|
||||
#endif
|
||||
} else {
|
||||
MS_LOG(ERROR) << "unsupported data type!";
|
||||
}
|
||||
|
||||
context_->allocator->Free(arg_param_->arg_elements_);
|
||||
arg_param_->arg_elements_ = nullptr;
|
||||
return RET_OK;
|
||||
|
@ -78,4 +91,6 @@ int ArgMinMaxCPUKernel::Run() {
|
|||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMax, LiteKernelCreator<ArgMinMaxCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMin, LiteKernelCreator<ArgMinMaxCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ArgMax, LiteKernelCreator<ArgMinMaxCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ArgMin, LiteKernelCreator<ArgMinMaxCPUKernel>)
|
||||
} // namespace mindspore::kernel
|
|
@ -13,12 +13,15 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARGMINMAX_BASE_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARGMINMAX_BASE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/fp32/arg_min_max_fp32.h"
|
||||
#ifdef ENABLE_ARM64
|
||||
#include "nnacl/fp16/arg_min_max_fp16.h"
|
||||
#endif
|
||||
#include "nnacl/common_func.h"
|
||||
#include "src/lite_kernel.h"
|
||||
|
||||
|
@ -43,4 +46,4 @@ class ArgMinMaxCPUKernel : public LiteKernel {
|
|||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARGMINMAX_BASE_H_
|
Loading…
Reference in New Issue