diff --git a/mindspore/lite/nnacl/arg_min_max_parameter.h b/mindspore/lite/nnacl/arg_min_max_parameter.h
index 949923d052a..9569d958ff4 100644
--- a/mindspore/lite/nnacl/arg_min_max_parameter.h
+++ b/mindspore/lite/nnacl/arg_min_max_parameter.h
@@ -17,14 +17,23 @@
 #ifndef MINDSPORE_LITE_NNACL_ARG_MIN_MAX_PARAMETER_H_
 #define MINDSPORE_LITE_NNACL_ARG_MIN_MAX_PARAMETER_H_
 
+#ifdef ENABLE_ARM64
+#include <arm_neon.h>
+#endif
 #include "nnacl/op_base.h"
 
+// Comparator signature handed to qsort() by the arg-min/max kernels.
+typedef int (*COMPARE_FUNCTION)(const void *a, const void *b);
+
 typedef struct ArgElement {
   uint32_t index_;
   union ArgData {
     int8_t i8_data_;
     int32_t i_data_;
     float f_data_;
+#ifdef ENABLE_ARM64
+    float16_t f16_data_;
+#endif
   } data_;
 } ArgElement;
 
diff --git a/mindspore/lite/nnacl/fp16/arg_min_max_fp16.c b/mindspore/lite/nnacl/fp16/arg_min_max_fp16.c
new file mode 100644
index 00000000000..7f7aef0b39e
--- /dev/null
+++ b/mindspore/lite/nnacl/fp16/arg_min_max_fp16.c
@@ -0,0 +1,256 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "nnacl/fp16/arg_min_max_fp16.h"
+
+// qsort comparator: orders ArgElement entries by ascending data_.f16_data_.
+int ArgCompareAscFp16(const void *a, const void *b) {
+  float16_t a_value = ((const ArgElement *)a)->data_.f16_data_;
+  float16_t b_value = ((const ArgElement *)b)->data_.f16_data_;
+  if (b_value > a_value) {
+    return -1;
+  }
+  if (b_value < a_value) {
+    return 1;
+  }
+
+  return 0;
+}
+
+// qsort comparator: orders ArgElement entries by descending data_.f16_data_.
+int ArgCompareDescFp16(const void *a, const void *b) {
+  float16_t b_value = ((const ArgElement *)b)->data_.f16_data_;
+  float16_t a_value = ((const ArgElement *)a)->data_.f16_data_;
+  if (b_value > a_value) {
+    return 1;
+  }
+  if (b_value < a_value) {
+    return -1;
+  }
+
+  return 0;
+}
+
+// Arg-max over the reduced axis for topk_ == 1. For every (pre, after) position writes the maximum (when
+// param->out_value_ is set) or its index to output, and the maximum to output_value if that pointer is non-NULL.
+void ArgMaxTopK1Fp16(const float16_t *input, float16_t *output, float16_t *output_value,
+                     const ArgMinMaxParameter *param, int pre_axis_count, int axis_count, int after_axis_count) {
+  bool out_value = param->out_value_;
+  for (int i = 0; i < pre_axis_count; ++i) {
+    size_t output_offset = i * after_axis_count;
+    size_t input_offset = output_offset * axis_count;
+    for (int j = 0; j < after_axis_count; ++j) {
+      // NOTE(review): -FLT_MAX is outside the fp16 range; presumably it converts to -inf - TODO confirm.
+      float16_t value = -FLT_MAX;
+      float16_t index = 0.0f;
+      for (int k = 0; k < axis_count; ++k) {
+        float16_t value_tmp = input[input_offset + k * after_axis_count + j];
+        if (value_tmp > value) {
+          value = value_tmp;
+          index = k;
+        }
+      }
+      output[output_offset + j] = out_value ? value : index;
+      if (output_value != NULL) {
+        output_value[output_offset + j] = value;
+      }
+    }
+  }
+}
+
+// Arg-min over the reduced axis for topk_ == 1. For every (pre, after) position writes the minimum (when
+// param->out_value_ is set) or its index to output, and the minimum to output_value if that pointer is non-NULL.
+void ArgMinTopK1Fp16(const float16_t *input, float16_t *output, float16_t *output_value,
+                     const ArgMinMaxParameter *param, int pre_axis_count, int axis_count, int after_axis_count) {
+  bool out_value = param->out_value_;
+  for (int i = 0; i < pre_axis_count; ++i) {
+    size_t output_offset = i * after_axis_count;
+    size_t input_offset = output_offset * axis_count;
+    for (int j = 0; j < after_axis_count; ++j) {
+      float16_t value = FLT_MAX;
+      float16_t index = 0.0f;
+      for (int k = 0; k < axis_count; ++k) {
+        float16_t value_tmp = input[input_offset + k * after_axis_count + j];
+        if (value_tmp < value) {
+          value = value_tmp;
+          index = k;
+        }
+      }
+      output[output_offset + j] = out_value ? value : index;
+      if (output_value != NULL) {
+        output_value[output_offset + j] = value;
+      }
+    }
+  }
+}
+
+// Top-k along axis 0: copies each axis-0 slice into param->arg_elements_, sorts it with compare_func and writes
+// the first topk_ values (out_value_ set) or indices to output; values also go to output_value when non-NULL.
+void ArgMinMaxDim0Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
+                       const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
+  for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
+    for (int j = 0; j < in_shape[0]; ++j) {
+      size_t offset = param->in_strides_[0] * j + i;
+      param->arg_elements_[j].index_ = j;
+      // Store through f16_data_: the fp16 comparators read that union member, not f_data_.
+      param->arg_elements_[j].data_.f16_data_ = input[offset];
+    }
+    qsort(param->arg_elements_, in_shape[0], sizeof(ArgElement), *compare_func);
+    for (int j = 0; j < param->topk_; ++j) {
+      size_t out_offset = j * param->out_strides_[0] + i;
+      output[out_offset] = param->out_value_ ? param->arg_elements_[j].data_.f16_data_ : param->arg_elements_[j].index_;
+      if (output_value != NULL) {
+        output_value[out_offset] = param->arg_elements_[j].data_.f16_data_;
+      }
+    }
+  }
+  return;
+}
+
+// Top-k along axis 1; same gather / qsort / scatter scheme as ArgMinMaxDim0Fp16.
+void ArgMinMaxDim1Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
+                       const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
+  int in_shape1 = in_shape[1];
+  for (int i = 0; i < in_shape[0]; ++i) {
+    size_t in_dim0_offset = i * param->in_strides_[0];
+    size_t out_dim0_offset = i * param->out_strides_[0];
+    for (int j = 0; j < param->in_strides_[1]; ++j) {
+      for (int k = 0; k < in_shape1; ++k) {
+        size_t offset = param->in_strides_[1] * k + in_dim0_offset + j;
+        param->arg_elements_[k].index_ = k;
+        param->arg_elements_[k].data_.f16_data_ = input[offset];
+      }
+      qsort(param->arg_elements_, in_shape1, sizeof(ArgElement), *compare_func);
+      for (int k = 0; k < param->topk_; ++k) {
+        size_t out_offset = out_dim0_offset + j + k * param->out_strides_[1];
+        output[out_offset] =
+          param->out_value_ ? param->arg_elements_[k].data_.f16_data_ : param->arg_elements_[k].index_;
+        if (output_value != NULL) {
+          output_value[out_offset] = param->arg_elements_[k].data_.f16_data_;
+        }
+      }
+    }
+  }
+  return;
+}
+
+// Top-k along axis 2; same gather / qsort / scatter scheme as ArgMinMaxDim0Fp16.
+void ArgMinMaxDim2Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
+                       const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
+  int in_shape1 = in_shape[1];
+  int in_shape2 = in_shape[2];
+  for (int i = 0; i < in_shape[0]; ++i) {
+    size_t in_dim0_offset = i * param->in_strides_[0];
+    size_t out_dim0_offset = i * param->out_strides_[0];
+    for (int j = 0; j < in_shape1; ++j) {
+      size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
+      size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
+      for (int k = 0; k < param->in_strides_[2]; ++k) {
+        for (int l = 0; l < in_shape2; ++l) {
+          size_t offset = param->in_strides_[2] * l + k + in_dim1_offset;
+          param->arg_elements_[l].index_ = l;
+          param->arg_elements_[l].data_.f16_data_ = input[offset];
+        }
+        qsort(param->arg_elements_, in_shape2, sizeof(ArgElement), *compare_func);
+        for (int l = 0; l < param->topk_; ++l) {
+          size_t out_offset = out_dim1_offset + k + l * param->out_strides_[2];
+
+          output[out_offset] =
+            param->out_value_ ? param->arg_elements_[l].data_.f16_data_ : param->arg_elements_[l].index_;
+          if (output_value != NULL) {
+            output_value[out_offset] = param->arg_elements_[l].data_.f16_data_;
+          }
+        }
+      }
+    }
+  }
+}
+
+// Top-k along axis 3; same gather / qsort / scatter scheme as ArgMinMaxDim0Fp16.
+void ArgMinMaxDim3Fp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
+                       const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
+  int in_shape1 = in_shape[1];
+  int in_shape2 = in_shape[2];
+  int in_shape3 = in_shape[3];
+  for (int i = 0; i < in_shape[0]; ++i) {
+    size_t in_dim0_offset = i * param->in_strides_[0];
+    size_t out_dim0_offset = i * param->out_strides_[0];
+    for (int j = 0; j < in_shape1; ++j) {
+      size_t in_dim1_offset = j * param->in_strides_[1] + in_dim0_offset;
+      size_t out_dim1_offset = j * param->out_strides_[1] + out_dim0_offset;
+      for (int k = 0; k < in_shape2; ++k) {
+        size_t in_dim2_offset = k * param->in_strides_[2] + in_dim1_offset;
+        size_t out_dim2_offset = k * param->out_strides_[2] + out_dim1_offset;
+        for (int l = 0; l < in_shape3; ++l) {
+          size_t offset = l + in_dim2_offset;
+          param->arg_elements_[l].index_ = l;
+          param->arg_elements_[l].data_.f16_data_ = input[offset];
+        }
+        qsort(param->arg_elements_, in_shape3, sizeof(ArgElement), *compare_func);
+        for (int l = 0; l < param->topk_; ++l) {
+          size_t out_offset = out_dim2_offset + l;
+          output[out_offset] =
+            param->out_value_ ? param->arg_elements_[l].data_.f16_data_ : param->arg_elements_[l].index_;
+          if (output_value != NULL) {
+            output_value[out_offset] = param->arg_elements_[l].data_.f16_data_;
+          }
+        }
+      }
+    }
+  }
+}
+
+// Entry point. topk_ == 1 uses the single-pass kernels above; otherwise the slice along param->axis_ is sorted
+// (descending when get_max_ is set, ascending otherwise). Only axes 0..3 are handled; other values do nothing.
+void ArgMinMaxFp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
+                   const ArgMinMaxParameter *param) {
+  if (param->topk_ == 1) {
+    int pre_axis_count = 1;
+    int axis_count = 1;
+    int after_axis_count = 1;
+    ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
+
+    if (param->get_max_) {
+      ArgMaxTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
+    } else {
+      ArgMinTopK1Fp16(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
+    }
+    return;
+  }
+
+  COMPARE_FUNCTION compare_function = NULL;
+  if (param->get_max_) {
+    compare_function = ArgCompareDescFp16;
+  } else {
+    compare_function = ArgCompareAscFp16;
+  }
+
+  switch (param->axis_) {
+    case 0:
+      ArgMinMaxDim0Fp16(input, output, output_value, in_shape, param, compare_function);
+      break;
+    case 1:
+      ArgMinMaxDim1Fp16(input, output, output_value, in_shape, param, compare_function);
+      break;
+    case 2:
+      ArgMinMaxDim2Fp16(input, output, output_value, in_shape, param, compare_function);
+      break;
+    case 3:
+      ArgMinMaxDim3Fp16(input, output, output_value, in_shape, param, compare_function);
+      break;
+  }
+  return;
+}
diff --git a/mindspore/lite/nnacl/fp16/arg_min_max_fp16.h b/mindspore/lite/nnacl/fp16/arg_min_max_fp16.h
new file mode 100644
index 00000000000..a7fecebdbe1
--- /dev/null
+++ b/mindspore/lite/nnacl/fp16/arg_min_max_fp16.h
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_NNACL_FP16_ARG_MIN_MAX_H_
+#define MINDSPORE_LITE_NNACL_FP16_ARG_MIN_MAX_H_
+
+#include <float.h>
+#include "nnacl/arg_min_max_parameter.h"
+#include "nnacl/nnacl_common.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+void ArgMinMaxFp16(const float16_t *input, float16_t *output, float16_t *output_value, const int *in_shape,
+                   const ArgMinMaxParameter *param);
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // MINDSPORE_LITE_NNACL_FP16_ARG_MIN_MAX_H_
diff --git a/mindspore/lite/nnacl/fp32/arg_min_max_fp32.c b/mindspore/lite/nnacl/fp32/arg_min_max_fp32.c
index dcc03f1e1c5..dd4024d49c8 100644
--- a/mindspore/lite/nnacl/fp32/arg_min_max_fp32.c
+++ b/mindspore/lite/nnacl/fp32/arg_min_max_fp32.c
@@ -91,21 +91,6 @@ void ArgMinTopK1(const float *input, float *output, float *output_value, const A
   }
 }
 
-void GetCalcParameter(const int *shape, int dims_number, int axis, int *pre_axis_count, int *axis_count,
-                      int *after_axis_count) {
-  *pre_axis_count = 1;
-  for (int i = 0; i < axis; ++i) {
-    *pre_axis_count = (*pre_axis_count) * shape[i];
-  }
-
-  *axis_count = shape[axis];
-
-  *after_axis_count = 1;
-  for (int i = axis + 1; i < dims_number; ++i) {
-    *after_axis_count = (*after_axis_count) * shape[i];
-  }
-}
-
 void ArgMinMaxDim0(const float *input, float *output, float *output_value, const int *in_shape,
                    const ArgMinMaxParameter *param, COMPARE_FUNCTION compare_func) {
   for (int32_t i = 0; i < param->in_strides_[0]; ++i) {
@@ -221,7 +206,7 @@ void ArgMinMaxFp32(const float *input, float *output, float *output_value, const
     int pre_axis_count = 1;
     int axis_count = 1;
     int after_axis_count = 1;
-    GetCalcParameter(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
+    ComputeAxisDims(in_shape, param->dims_size_, param->axis_, &pre_axis_count, &axis_count, &after_axis_count);
 
     if (param->get_max_) {
       ArgMaxTopK1(input, output, output_value, param, pre_axis_count, axis_count, after_axis_count);
diff --git a/mindspore/lite/nnacl/fp32/arg_min_max_fp32.h b/mindspore/lite/nnacl/fp32/arg_min_max_fp32.h
index fa0a3c20a67..18146432c6f 100644
--- a/mindspore/lite/nnacl/fp32/arg_min_max_fp32.h
+++ b/mindspore/lite/nnacl/fp32/arg_min_max_fp32.h
@@ -16,10 +16,9 @@
 #ifndef MINDSPORE_LITE_NNACL_FP32_ARG_MIN_MAX_H_
 #define MINDSPORE_LITE_NNACL_FP32_ARG_MIN_MAX_H_
 
+#include "nnacl/nnacl_common.h"
 #include "nnacl/arg_min_max_parameter.h"
 
-typedef int (*COMPARE_FUNCTION)(const void *a, const void *b);
-
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/mindspore/lite/nnacl/nnacl_common.h b/mindspore/lite/nnacl/nnacl_common.h
index ae1adccf463..365257c9588 100644
--- a/mindspore/lite/nnacl/nnacl_common.h
+++ b/mindspore/lite/nnacl/nnacl_common.h
@@ -31,6 +31,20 @@ inline void ComputeStrides(const int *shape, int *strides, const int ndim) {
   }
 }
 
+// Splits shape around axis: out_count receives the product of the dimensions before axis, in_count the product
+// of the dimensions after it, and axis_count receives shape[axis]. axis_count is left untouched when axis is
+// not in [0, shape_size), so callers should initialise it before the call.
+static inline void ComputeAxisDims(const int *shape, int shape_size, int axis, int *out_count, int *axis_count,
+                                   int *in_count) {
+  *out_count = 1;
+  *in_count = 1;
+  for (int i = 0; i < shape_size; i++) {
+    if (i < axis) *out_count = (*out_count) * shape[i];
+    if (i == axis) *axis_count = shape[axis];
+    if (i > axis) *in_count = (*in_count) * shape[i];
+  }
+}
+
 static const unsigned int FP32_BIT_SIZE = 32;
 static const unsigned int FP32_EXPONENT_BIAS = 127;
 static const unsigned int FP32_SIGNIFICAND = 23;
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/base/argminmax_base.cc
similarity index 69%
rename from mindspore/lite/src/runtime/kernel/arm/fp32/argminmax_fp32.cc
rename to
mindspore/lite/src/runtime/kernel/arm/base/argminmax_base.cc
index 2a89108d0c7..660997a61ce 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/base/argminmax_base.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "src/runtime/kernel/arm/fp32/argminmax_fp32.h"
+#include "src/runtime/kernel/arm/base/argminmax_base.h"
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 
@@ -52,25 +52,43 @@ int ArgMinMaxCPUKernel::ReSize() {
 }
 
 int ArgMinMaxCPUKernel::Run() {
-  float *input_data = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
-  float *output_data = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
-  float *output_value = nullptr;
-  if (out_tensors_.size() == 2) {
-    output_value = reinterpret_cast<float *>(out_tensors_.at(1)->data_c());
-  }
+  auto input = in_tensors_.at(0);
+  auto shape = input->shape();
 
-  auto shape = in_tensors_.at(0)->shape();
+  auto input_data = input->data_c();
+  auto output_data = out_tensors_.at(0)->data_c();
+  void *output_value = nullptr;
+  if (out_tensors_.size() == 2) {
+    output_value = out_tensors_.at(1)->data_c();
+  }
 
   MS_ASSERT(context_->allocator != nullptr);
   if (arg_param_->topk_ > 1 || arg_param_->keep_dims_) {
     arg_param_->arg_elements_ =
       reinterpret_cast<ArgElement *>(context_->allocator->Malloc(sizeof(ArgElement) * shape[arg_param_->axis_]));
     if (arg_param_->arg_elements_ == nullptr) {
-      MS_LOG(ERROR) << "malloc memroy fail!";
+      MS_LOG(ERROR) << "malloc memory fail!";
       return RET_ERROR;
     }
   }
-  ArgMinMaxFp32(input_data, output_data, output_value, reinterpret_cast<const int *>(shape.data()), arg_param_);
+  // Dispatch on the input element type; the fp16 path is only compiled when ENABLE_ARM64 is defined.
+  if (input->data_type() == kNumberTypeFloat32) {
+    ArgMinMaxFp32(reinterpret_cast<float *>(input_data), reinterpret_cast<float *>(output_data),
+                  reinterpret_cast<float *>(output_value), shape.data(), arg_param_);
+#ifdef ENABLE_ARM64
+  } else if (input->data_type() == kNumberTypeFloat16) {
+    ArgMinMaxFp16(reinterpret_cast<float16_t *>(input_data), reinterpret_cast<float16_t *>(output_data),
+                  reinterpret_cast<float16_t *>(output_value), shape.data(), arg_param_);
+
+#endif
+  } else {
+    MS_LOG(ERROR) << "unsupported data type!";
+    // Release the scratch buffer and report the failure instead of falling through to RET_OK.
+    context_->allocator->Free(arg_param_->arg_elements_);
+    arg_param_->arg_elements_ = nullptr;
+    return RET_ERROR;
+  }
+
   context_->allocator->Free(arg_param_->arg_elements_);
   arg_param_->arg_elements_ = nullptr;
   return RET_OK;
@@ -78,4 +96,6 @@ int ArgMinMaxCPUKernel::Run() {
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMax, LiteKernelCreator<ArgMinMaxCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ArgMin, LiteKernelCreator<ArgMinMaxCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ArgMax, LiteKernelCreator<ArgMinMaxCPUKernel>)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ArgMin, LiteKernelCreator<ArgMinMaxCPUKernel>)
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax_fp32.h b/mindspore/lite/src/runtime/kernel/arm/base/argminmax_base.h
similarity index 83%
rename from mindspore/lite/src/runtime/kernel/arm/fp32/argminmax_fp32.h
rename to mindspore/lite/src/runtime/kernel/arm/base/argminmax_base.h
index f9dc051443e..65dde8edf4e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/argminmax_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/base/argminmax_base.h
@@ -13,12 +13,15 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_
-#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_
+#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARGMINMAX_BASE_H_
+#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARGMINMAX_BASE_H_
 
 #include <vector>
 #include "include/errorcode.h"
 #include "nnacl/fp32/arg_min_max_fp32.h"
+#ifdef ENABLE_ARM64
+#include "nnacl/fp16/arg_min_max_fp16.h"
+#endif
 #include "nnacl/common_func.h"
 #include "src/lite_kernel.h"
 
@@ -43,4 +46,4 @@ class ArgMinMaxCPUKernel : public LiteKernel {
 };
 }  // namespace mindspore::kernel
 
-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARGMINMAX_H_
+#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_ARGMINMAX_BASE_H_