forked from mindspore-Ecosystem/mindspore

!24459 [MS][LITE][develop] add fp16 prelu kernel
Merge pull request !24459 from sunsuodong/add_fp16_prelu
Commit 7b5cebed78

nnacl/fp16/prelu_fp16.c
@@ -0,0 +1,144 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "nnacl/fp16/prelu_fp16.h"

#ifdef ENABLE_ARM64
static inline void PReluFp164x32(const float16_t *in, float16_t *out, const float16_t *cur_slope, size_t step) {
  asm volatile(
    "mov x10, %[in]\n"
    "mov x11, %[out]\n"
    "mov x12, %[cur_slope]\n"
    "ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x12]\n"
    "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], %[step]\n"
    "fmul v16.8h, v0.8h, v4.8h\n"
    "fmul v17.8h, v1.8h, v5.8h\n"
    "fmul v18.8h, v2.8h, v6.8h\n"
    "fmul v19.8h, v3.8h, v7.8h\n"
    "fcmgt v20.8h, v0.8h, #0\n"
    "fcmgt v21.8h, v1.8h, #0\n"
    "fcmgt v22.8h, v2.8h, #0\n"
    "fcmgt v23.8h, v3.8h, #0\n"
    "ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10], %[step]\n"
    "bif v0.16b, v16.16b, v20.16b\n"
    "bif v1.16b, v17.16b, v21.16b\n"
    "bif v2.16b, v18.16b, v22.16b\n"
    "bif v3.16b, v19.16b, v23.16b\n"
    "fmul v8.8h, v24.8h, v4.8h\n"
    "fmul v9.8h, v25.8h, v5.8h\n"
    "fmul v10.8h, v26.8h, v6.8h\n"
    "fmul v11.8h, v27.8h, v7.8h\n"
    "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], %[step]\n"
    "fcmgt v12.8h, v24.8h, #0\n"
    "fcmgt v13.8h, v25.8h, #0\n"
    "fcmgt v14.8h, v26.8h, #0\n"
    "fcmgt v15.8h, v27.8h, #0\n"
    "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], %[step]\n"
    "bif v24.16b, v8.16b, v12.16b\n"
    "bif v25.16b, v9.16b, v13.16b\n"
    "bif v26.16b, v10.16b, v14.16b\n"
    "bif v27.16b, v11.16b, v15.16b\n"
    "fmul v16.8h, v0.8h, v4.8h\n"
    "fmul v17.8h, v1.8h, v5.8h\n"
    "fmul v18.8h, v2.8h, v6.8h\n"
    "fmul v19.8h, v3.8h, v7.8h\n"
    "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x11], %[step]\n"
    "fcmgt v20.8h, v0.8h, #0\n"
    "fcmgt v21.8h, v1.8h, #0\n"
    "fcmgt v22.8h, v2.8h, #0\n"
    "fcmgt v23.8h, v3.8h, #0\n"
    "ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10]\n"
    "bif v0.16b, v16.16b, v20.16b\n"
    "bif v1.16b, v17.16b, v21.16b\n"
    "bif v2.16b, v18.16b, v22.16b\n"
    "bif v3.16b, v19.16b, v23.16b\n"
    "fmul v8.8h, v24.8h, v4.8h\n"
    "fmul v9.8h, v25.8h, v5.8h\n"
    "fmul v10.8h, v26.8h, v6.8h\n"
    "fmul v11.8h, v27.8h, v7.8h\n"
    "fcmgt v12.8h, v24.8h, #0\n"
    "fcmgt v13.8h, v25.8h, #0\n"
    "fcmgt v14.8h, v26.8h, #0\n"
    "fcmgt v15.8h, v27.8h, #0\n"
    "st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x11], %[step]\n"
    "bif v24.16b, v8.16b, v12.16b\n"
    "bif v25.16b, v9.16b, v13.16b\n"
    "bif v26.16b, v10.16b, v14.16b\n"
    "bif v27.16b, v11.16b, v15.16b\n"
    "st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x11]\n"
    :
    : [ in ] "r"(in), [ out ] "r"(out), [ cur_slope ] "r"(cur_slope), [ step ] "r"(step)
    : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13",
      "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27");
}
#endif
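For readers who don't speak AArch64: the block above is a software-pipelined PReLU over a 4-row by 32-channel fp16 tile (fmul computes in * slope, fcmgt builds the in > 0 mask, and bif keeps positive lanes while inserting the scaled ones). A scalar sketch of the same computation, offered for orientation only and not part of the commit:

// Scalar equivalent of PReluFp164x32 (illustration only). At the call site
// below, step is channel * sizeof(float16_t), i.e. one full row in bytes.
static void PReluFp164x32Scalar(const float16_t *in, float16_t *out, const float16_t *cur_slope, size_t step) {
  size_t stride = step / sizeof(float16_t);  // row stride in elements
  for (size_t row = 0; row < 4; ++row) {
    for (size_t c = 0; c < 32; ++c) {
      float16_t v = in[row * stride + c];
      out[row * stride + c] = v > 0 ? v : (float16_t)(v * cur_slope[c]);
    }
  }
}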

void PReluFp16(const float16_t *input, float16_t *output, const float16_t *slope, int start, int end, int channel) {
  int i = start;
#ifdef ENABLE_ARM64
  for (; i <= end - C4NUM; i += C4NUM) {
    const float16_t *cur_in = input + i * channel;
    float16_t *cur_out = output + i * channel;
    int j = 0;
    for (; j <= channel - C32NUM; j += C32NUM) {
      const float16_t *in = cur_in + j;
      float16_t *out = cur_out + j;
      const float16_t *cur_slope = slope + j;
      size_t step = channel * sizeof(float16_t);
      PReluFp164x32(in, out, cur_slope, step);
    }
    for (; j < channel; j++) {
      cur_out[j] = (cur_in[j] > 0) ? cur_in[j] : (cur_in[j] * slope[j]);
      cur_out[j + channel] = (cur_in[j + channel] > 0) ? cur_in[j + channel] : cur_in[j + channel] * slope[j];
      cur_out[j + 2 * channel] =
        (cur_in[j + 2 * channel] > 0) ? cur_in[j + 2 * channel] : (cur_in[j + 2 * channel] * slope[j]);
      cur_out[j + 3 * channel] =
        (cur_in[j + 3 * channel] > 0) ? cur_in[j + 3 * channel] : (cur_in[j + 3 * channel] * slope[j]);
    }
  }
#endif
  for (; i < end; i++) {
    const float16_t *cur_in = input + i * channel;
    float16_t *cur_out = output + i * channel;
    int j = 0;
#ifdef ENABLE_NEON
    for (; j <= channel - C8NUM; j += C8NUM) {
      float16x8_t in = vld1q_f16(cur_in + j);
      float16x8_t s = vld1q_f16(slope + j);
      float16x8_t mul = vmulq_f16(in, s);
      uint16x8_t mask = vcgtq_f16(in, vmovq_n_f16(0.0f));
      vst1q_f16(cur_out + j, vbslq_f16(mask, in, mul));
    }
#endif
    for (; j < channel; j++) {
      cur_out[j] = cur_in[j] > 0 ? cur_in[j] : cur_in[j] * slope[j];
    }
  }
}

void PReluShareChannelFp16(const float16_t *input, float16_t *output, float16_t slope, int start, int end) {
  int i = start;
#ifdef ENABLE_NEON
  for (; i <= end - C8NUM; i += C8NUM) {
    float16x8_t src_tmp = vld1q_f16(input + i);
    float16x8_t mul_tmp = vmulq_n_f16(src_tmp, slope);
    uint16x8_t mask = vcgtq_f16(src_tmp, vmovq_n_f16(0.0f));
    vst1q_f16(output + i, vbslq_f16(mask, src_tmp, mul_tmp));
  }
#endif
  for (; i < end; i++) {
    output[i] = input[i] > 0 ? input[i] : input[i] * slope;
  }
}
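A hypothetical caller of the two entry points above (sizes invented for illustration): PReluFp16 treats the data as end - start rows of channel values with per-channel slopes, while PReluShareChannelFp16 treats it as a flat element range with a single slope:

// Usage sketch, not from the commit; assumes a toolchain with float16_t.
#include "nnacl/fp16/prelu_fp16.h"

void PReluFp16Example(void) {
  enum { kPlane = 6, kChannel = 16 };  // hypothetical NHWC plane/channel sizes
  float16_t in[kPlane * kChannel] = {0};
  float16_t out[kPlane * kChannel];
  float16_t slope[kChannel] = {0};
  PReluFp16(in, out, slope, 0, kPlane, kChannel);                  // rows [0, kPlane)
  PReluShareChannelFp16(in, out, slope[0], 0, kPlane * kChannel);  // flat element range
}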

nnacl/fp16/prelu_fp16.h
@@ -0,0 +1,31 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_NNACL_FP16_PRELU_FP16_H_
#define MINDSPORE_NNACL_FP16_PRELU_FP16_H_

#include "nnacl/op_base.h"

#ifdef __cplusplus
extern "C" {
#endif
void PReluFp16(const float16_t *input, float16_t *output, const float16_t *slope, int start, int end, int channel);

void PReluShareChannelFp16(const float16_t *input, float16_t *output, float16_t slope, int start, int end);
#ifdef __cplusplus
}
#endif

#endif  // MINDSPORE_NNACL_FP16_PRELU_FP16_H_
nnacl/fp32/prelu_fp32.h
@@ -17,7 +17,6 @@
#define MINDSPORE_NNACL_FP32_PRELU_H_

#include "nnacl/op_base.h"
-#include "nnacl/prelu_parameter.h"

#ifdef __cplusplus
extern "C" {

@@ -95,12 +95,12 @@ static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) {
#define MS_FLOAT16X8 float16x8_t
#define MS_FLOAT16X4 float16x4_t
#define MS_MOVQ_F16 vmovq_n_f16
-#define MS_STQ_F16 vst1q_f16
+#define MS_STQ_F16(ptr, val) vst1q_f16(ptr, val)
#define MS_ST_F16 vst1_f16
#define MS_MINQ_F16 vminq_f16
#define MS_MAXQ_F16 vmaxq_f16
-#define MS_LDQ_F16 vld1q_f16
-#define MS_LD_F16 vld1_f16
+#define MS_LDQ_F16(ptr) vld1q_f16(ptr)
+#define MS_LD_F16(ptr) vld1_f16(ptr)
#define MS_ADDQ_F16 vaddq_f16
#define MS_SUBQ_F16 vsubq_f16
#define MS_MULQ_F16 vmulq_f16
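The store/load aliases above change from object-like to function-like macros. Presumably the point is safety: a function-like macro is only expanded when followed by an argument list, and it checks the argument count at expansion time. A small sketch of the difference (buffers hypothetical, not from the commit):

// Sketch: only real call sites expand with the function-like form.
static inline void MacroExample(float16_t *src, float16_t *dst) {
  float16x8_t v = MS_LDQ_F16(src);  // expands to vld1q_f16(src)
  MS_STQ_F16(dst, v);               // expands to vst1q_f16(dst, v)
}
// With the old object-like form, a bare `MS_LDQ_F16` anywhere would have
// silently expanded to `vld1q_f16`; the function-like form leaves the bare
// name alone, so accidental uses fail to compile.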

nnacl/prelu_parameter.h
@@ -22,9 +22,7 @@ typedef struct PReluParameter {
  // Primitive parameter
  OpParameter op_parameter_;
  // other parameter
-  float *slope_;
  bool channelShared;
-  int tile_block_;
  int channel_num_;
  int input_num_;
} PReluParameter;

src/runtime/kernel/arm/fp16/prelu_fp16.cc
@@ -0,0 +1,50 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "src/runtime/kernel/arm/fp16/prelu_fp16.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "nnacl/fp16/prelu_fp16.h"

using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_PReLUFusion;

namespace mindspore::kernel {
int PReluFp16CPUKernel::DoExcute(int task_id) {
  int thread_num = param_->op_parameter_.thread_num_;
  if (thread_num == 0) {
    MS_LOG(ERROR) << "thread_num is 0!";
    return RET_ERROR;
  }
  int num = param_->channelShared ? param_->input_num_ : param_->input_num_ / param_->channel_num_;
  int step = UP_DIV(num, thread_num);
  int start = task_id * step;
  int end = MSMIN(start + step, num);

  if (param_->channelShared) {
    PReluShareChannelFp16(static_cast<float16_t *>(input_data_), static_cast<float16_t *>(output_data_),
                          static_cast<float16_t *>(slope_data_)[0], start, end);
  } else {
    PReluFp16(static_cast<float16_t *>(input_data_), static_cast<float16_t *>(output_data_),
              static_cast<float16_t *>(slope_data_), start, end, param_->channel_num_);
  }
  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_PReLUFusion, LiteKernelCreator<PReluFp16CPUKernel>)
}  // namespace mindspore::kernel
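Both branches above slice the same [start, end) range out of num work items, one slice per task. A worked sketch of the arithmetic with invented numbers, assuming the usual round-up and minimum definitions of UP_DIV and MSMIN from nnacl/op_base.h:

// num = 100 items split across thread_num = 3 tasks.
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))  // assumed definition
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))   // assumed definition

static void PartitionExample(void) {
  int num = 100, thread_num = 3;
  int step = UP_DIV(num, thread_num);    // 34
  for (int task_id = 0; task_id < thread_num; ++task_id) {
    int start = task_id * step;          // 0, 34, 68
    int end = MSMIN(start + step, num);  // 34, 68, 100 (tail clamped)
  }
}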

src/runtime/kernel/arm/fp16/prelu_fp16.h
@@ -0,0 +1,33 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_PRELU_FP16_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_PRELU_FP16_H_

#include <vector>
#include "src/runtime/kernel/arm/fp32/prelu_fp32.h"

namespace mindspore::kernel {
class PReluFp16CPUKernel : public PReluCPUKernel {
 public:
  PReluFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : PReluCPUKernel(parameter, inputs, outputs, ctx) {}
  ~PReluFp16CPUKernel() = default;

  int DoExcute(int task_id) override;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_PRELU_FP16_H_

src/runtime/kernel/arm/fp32/prelu_fp32.cc
@@ -18,6 +18,7 @@
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
+#include "nnacl/fp32/prelu_fp32.h"

using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;

@@ -37,12 +38,16 @@ static int PReluRun(void *cdata, int task_id, float lhs_scale, float rhs_scale)
}

int PReluCPUKernel::Prepare() {
+  constexpr int kSlopeIndex = 1;
  CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
  CHECK_LESS_RETURN(out_tensors_.size(), 1);
+  CHECK_NULL_RETURN(in_tensors_[kInputIndex]);
+  CHECK_NULL_RETURN(in_tensors_[kSlopeIndex]);
+  CHECK_NULL_RETURN(out_tensors_[kOutputIndex]);
  if (in_tensors_[1]->ElementsNum() == 1) {
-    prelu_param_->channelShared = true;
+    param_->channelShared = true;
  } else {
-    prelu_param_->channelShared = false;
+    param_->channelShared = false;
  }
  if (!InferShapeDone()) {
    return RET_OK;

@@ -51,59 +56,47 @@
}

int PReluCPUKernel::DoExcute(int task_id) {
-  int thread_num = prelu_param_->op_parameter_.thread_num_;
+  int thread_num = param_->op_parameter_.thread_num_;
  if (thread_num == 0) {
    MS_LOG(ERROR) << "thread_num is 0!";
    return RET_ERROR;
  }
-  if (prelu_param_->channelShared) {
-    int step = UP_DIV(prelu_param_->input_num_, thread_num);
-    int start = task_id * step;
-    int end = MSMIN(start + step, prelu_param_->input_num_);
-    PReluShareChannel(input_data_, output_data_, prelu_param_->slope_[0], start, end);
+  int num = param_->channelShared ? param_->input_num_ : param_->input_num_ / param_->channel_num_;
+  int step = UP_DIV(num, thread_num);
+  int start = task_id * step;
+  int end = MSMIN(start + step, num);
+
+  if (param_->channelShared) {
+    PReluShareChannel(static_cast<float *>(input_data_), static_cast<float *>(output_data_),
+                      static_cast<float *>(slope_data_)[0], start, end);
  } else {
-    int step = UP_DIV(prelu_param_->tile_block_, thread_num);
-    int start = task_id * step;
-    int end = MSMIN(start + step, prelu_param_->tile_block_);
-    PRelu(input_data_, output_data_, prelu_param_->slope_, start, end, prelu_param_->channel_num_);
+    PRelu(static_cast<float *>(input_data_), static_cast<float *>(output_data_), static_cast<float *>(slope_data_),
+          start, end, param_->channel_num_);
  }
  return RET_OK;
}

int PReluCPUKernel::ReSize() {
-  auto input_tensor = in_tensors_.at(0);
-  auto in_shape = input_tensor->shape();
-  auto n_dim = in_shape.size();
-  auto channel_num = in_shape.at(n_dim - 1);
-  int input_plane = 1;
-  for (size_t i = 0; i < n_dim - 1; ++i) {
-    input_plane *= in_shape.at(i);
-  }
-  MS_CHECK_FALSE(INT_MUL_OVERFLOW(input_plane, channel_num), RET_ERROR);
-  prelu_param_->input_num_ = input_plane * channel_num;
-  prelu_param_->tile_block_ = input_plane;
-  prelu_param_->channel_num_ = channel_num;
+  auto &input = in_tensors_[kInputIndex];
+  param_->input_num_ = input->ElementsNum();
+  param_->channel_num_ = input->Channel();
  return RET_OK;
}

int PReluCPUKernel::Run() {
-  auto input_tensor = in_tensors_[0];
-  input_data_ = reinterpret_cast<float *>(input_tensor->data());
-  output_data_ = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->data());
+  constexpr int kSlopeIndex = 1;
+  input_data_ = in_tensors_[kInputIndex]->data();
+  slope_data_ = in_tensors_[kSlopeIndex]->data();
+  output_data_ = out_tensors_[kOutputIndex]->data();
  CHECK_NULL_RETURN(input_data_);
+  CHECK_NULL_RETURN(slope_data_);
  CHECK_NULL_RETURN(output_data_);

-  // negative slope tensor
-  auto negative_slope_tensor = in_tensors_.at(1);
-  CHECK_NULL_RETURN(negative_slope_tensor->data());
-  prelu_param_->slope_ = reinterpret_cast<float *>(negative_slope_tensor->data());
-
-  auto ret = ParallelLaunch(this->ms_context_, PReluRun, this, prelu_param_->op_parameter_.thread_num_);
+  auto ret = ParallelLaunch(this->ms_context_, PReluRun, this, param_->op_parameter_.thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "PRelu Run error: error_code[" << ret << "]";
    return RET_ERROR;
  }
-
  return RET_OK;
}

src/runtime/kernel/arm/fp32/prelu_fp32.h
@@ -19,7 +19,7 @@
#include <vector>
#include "src/inner_kernel.h"
#include "include/context.h"
-#include "nnacl/fp32/prelu_fp32.h"
+#include "nnacl/prelu_parameter.h"

namespace mindspore::kernel {
class PReluCPUKernel : public InnerKernel {

@@ -27,19 +27,20 @@ class PReluCPUKernel : public InnerKernel {
  PReluCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                 const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : InnerKernel(parameter, inputs, outputs, ctx) {
-    prelu_param_ = reinterpret_cast<PReluParameter *>(op_parameter_);
+    param_ = reinterpret_cast<PReluParameter *>(op_parameter_);
  }
  ~PReluCPUKernel() = default;

  int Prepare() override;
  int ReSize() override;
  int Run() override;
-  int DoExcute(int task_id);
+  virtual int DoExcute(int task_id);

- private:
-  PReluParameter *prelu_param_;
-  float *input_data_ = nullptr;
-  float *output_data_ = nullptr;
+ protected:
+  PReluParameter *param_;
+  void *input_data_ = nullptr;
+  void *slope_data_ = nullptr;
+  void *output_data_ = nullptr;
};
}  // namespace mindspore::kernel
#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_PRELU_H_
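The two changes above (DoExcute made virtual, the data members made protected void pointers) are what let PReluFp16CPUKernel reuse the fp32 Prepare/ReSize/Run unchanged: Run launches PReluRun, which calls DoExcute on a base-class pointer, so the call must dispatch dynamically to reach the fp16 override. A minimal sketch with hypothetical types:

// Why DoExcute must be virtual: the shared driver lives in the base class.
struct BaseKernel {
  virtual ~BaseKernel() = default;
  int Run() { return DoExcute(0); }                // shared driver (cf. PReluRun)
  virtual int DoExcute(int task_id) { return 0; }  // fp32 path
};
struct Fp16Kernel : BaseKernel {
  int DoExcute(int task_id) override { return 1; }  // fp16 path
};
// With BaseKernel *k = new Fp16Kernel, k->Run() now reaches the fp16 DoExcute;
// without `virtual`, it would silently run the fp32 version.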

@@ -17,7 +17,7 @@ hiai_face_RFB-Epoch-170-no-transpose 4
tracking 4
mtk_landmark 1
mtk_pose_tuku 1
-mtk_face_recognition_v1 20
+mtk_face_recognition_v1 21
mtk_2012_ATLANTA_10class_20190614_v41 4
mtk_detect-deeper-halfdeeper-mbv1-lastearlySSD-shortcut-400-400_nopostprocess_simplified 4
# mtk_detect-deeper-halfdeeper-mbv1-shortcut-400-400_nopostprocess_simplified: precision is 5%

@@ -38,11 +38,11 @@ ml_hardware_pose 2
ml_bank_recog 0.1
2012_ATLANTA_10class_20190131_v4.0 12
mnet 12
-recognition 10
+recognition 10.8
ml_face_landmark 1
model_hebing_3branch 40
hiai_cv_focusShootOCRModel_07 3
-hiai_cv_focusShootOCRModel_03 60
+hiai_cv_focusShootOCRModel_03 169
hiai_cv_focusShootOCRModel_01 14
hiai_face_hat1 1.7
hiai_cv_focusShootOCRModel_04 8

@@ -51,7 +51,7 @@ hiai_cpu_face_hat 1.7
hiai_video_seg 1
hiai_semantic_seg 3
hiai_human_seg 28
-hiai_face_recognition_1 10
+hiai_face_recognition_1 10.8
hiai_cpu_face_detect 4.5
hiai_cpu_face_attr 82.3 # divided by small number causes big bias
hiai_face_attr1 82.3 # divided by small number causes big bias

@@ -62,7 +62,7 @@ ml_ei_facedetection.onnx 2
#ml_video_edit_art_generate.onnx #mul operator overflows, not suitable for fp16
#ml_voice_detect.onnx #conv operator overflows, not suitable for fp16
#ml_location_lane_counter.onnx has very small values during op computation (<1e-6), which causes the precision variation
-ml_location_lane_counter.onnx 8
+ml_location_lane_counter.onnx 8.3
ml_location_lane_counter0.onnx 1.0
#The encoder and decoder models are used in the ml_asr scene; both have value overflow. Not suitable for fp16.
#But added for guarding process.

@@ -103,7 +103,7 @@ ml_asr_decoder_202103.onnx;2;1,64,512:1,64 0.5
ml_video_edit_makeup_mobilenetv203.onnx 4
# The input of ml_video_edit_hair_dyeing_migrate_v2.onnx should be between [0, 1]
ml_video_edit_hair_dyeing_migrate_v2.onnx;4 2.5
-Q888_CV_face_recognition_self.onnx 3.6
+Q888_CV_face_recognition_self.onnx 3.9
ml_video_edit_hair_dyeing_migrate_v2_fix.onnx;4 3
ml_intelligent_cockpit_model.onnx;3;1,32:1,32:1,32 3.5
CloudBU_FSRCNN_RTC_8ch_3450_QP9.onnx;1;1,225,225,3 1.5

@@ -130,7 +130,7 @@ mtk_model_emotions_0727_nosoftmax.tflite 2
mtk_model_normalize_object_scene_ps_20200826_f32_no_softmax.tflite 22
mtk_276landmark_0913.tflite 16
mtk_face_recognition.tflite 8
-mtk_convert_model.tflite 5
+mtk_convert_model.tflite 5.3
smartreply.tflite 0.1
mindspore_text_classification_tflite.tflite 9.2 # small output causes big bias
#ml_location.tflite 0.1