!22879 [MS][LITE][Develop] add lite fp16 kernel
Merge pull request !22879 from sunsuodong/add_fp16_kernel
This commit is contained in:
commit
dd32f03a26
|
@ -16,6 +16,7 @@
|
|||
#include "nnacl/fp16/activation_fp16.h"
|
||||
#include <float.h>
|
||||
#include "nnacl/fp32/exp_fp32.h"
|
||||
#include "nnacl/fp16/exp_fp16.h"
|
||||
#include "nnacl/errorcode.h"
|
||||
|
||||
int ReluFp16(const float16_t *src, float16_t *dst, int ele_num) {
|
||||
|
@ -249,3 +250,22 @@ int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate)
|
|||
}
|
||||
return NNACL_OK;
|
||||
}
|
||||
|
||||
// ELU activation on fp16 data:
//   dst[i] = src[i]                      when src[i] > 0
//   dst[i] = alpha * (exp(src[i]) - 1)   otherwise
// src and dst are indexed over `length` elements. Always returns NNACL_OK.
int EluFp16(const float16_t *src, int length, float16_t *dst, float16_t alpha) {
  int i = 0;
#ifdef ENABLE_NEON
  // Loop-invariant vectors, built once outside the loop.
  const float16x8_t one = MS_MOVQ_F16(1.0f);
  const float16x8_t zero = MS_MOVQ_F16(0.0f);
  for (; i <= length - 8; i += 8) {
    float16x8_t src_tmp = MS_LDQ_F16(src + i);
    float16x8_t exp_tmp = VexpFp16(src_tmp);  // exp(x)
    exp_tmp = MS_SUBQ_F16(exp_tmp, one);      // exp(x) - 1
    float16x8_t elu_tmp = MS_MULQ_N_F16(exp_tmp, alpha);
    // Compare with the fp16 intrinsic: the operands are float16x8_t, so the
    // result must be a uint16x8_t lane mask (all ones where x > 0).
    uint16x8_t mask = vcgtq_f16(src_tmp, zero);
    // vbslq_f16 takes the mask first; set mask bits select from the second
    // operand (x itself), cleared bits select from the third (the ELU branch).
    MS_STQ_F16(dst + i, vbslq_f16(mask, src_tmp, elu_tmp));
  }
#endif
  // Scalar path: handles the remaining tail, or everything when NEON is disabled.
  for (; i < length; ++i) {
    dst[i] = src[i] > 0 ? src[i] : (expm1(src[i]) * alpha);
  }
  return NNACL_OK;
}
|
||||
|
|
|
@ -34,6 +34,7 @@ int HSigmoidFp16(const float16_t *src, float16_t *dst, int ele_num);
|
|||
int SwishFp16(const float16_t *src, float16_t *dst, int ele_num);
|
||||
int HardTanhFp16(const float16_t *src, int length, float16_t *dst, float min_val, float max_val);
|
||||
int GeluFp16(const float16_t *src, int length, float16_t *dst, bool approximate);
|
||||
int EluFp16(const float16_t *src, int length, float16_t *dst, float16_t alpha);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -23,9 +23,7 @@ static inline void simd_exp_fp16(float16x8_t input, float16_t *dst) {
|
|||
static float16x8_t maxv = {88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f, 88.0f};
|
||||
static float16x8_t minv = {-88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f, -88.0f};
|
||||
input = vmaxq_f16(minv, vminq_f16(input, maxv));
|
||||
float32x4_t input_low = vcvt_f32_f16(vget_low_f16(input));
|
||||
float32x4_t input_high = vcvt_f32_f16(vget_high_f16(input));
|
||||
vst1q_f16(dst, vcombine_f16(vcvt_f16_f32(VexpFp32(input_low)), vcvt_f16_f32(VexpFp32(input_high))));
|
||||
vst1q_f16(dst, VexpFp16(input));
|
||||
}
|
||||
#endif
|
||||
|
||||
|
|
|
@ -27,6 +27,14 @@ extern "C" {
|
|||
void ExpFp16(const float16_t *src, float16_t *dst, int num);
|
||||
int ExpFusionFp16(const float16_t *src, float16_t *dst, const ExpParameter *param, int task_id);
|
||||
|
||||
#ifdef ENABLE_NEON
|
||||
// Computes exp on eight fp16 lanes: each half of the vector is widened to
// fp32, passed through VexpFp32, narrowed back to fp16, and the two halves
// are recombined in their original order.
static inline float16x8_t VexpFp16(float16x8_t input) {
  float32x4_t lo_f32 = MS_CVT_F32_F16(vget_low_f16(input));
  float32x4_t hi_f32 = MS_CVT_F32_F16(vget_high_f16(input));
  float16x4_t lo_exp = MS_CVT_F16_F32(VexpFp32(lo_f32));
  float16x4_t hi_exp = MS_CVT_F16_F32(VexpFp32(hi_f32));
  return vcombine_f16(lo_exp, hi_exp);
}
|
||||
#endif
|
||||
|
||||
static inline void single_exp_fp16(float16_t src, float16_t *dst) {
|
||||
static float param[] = {0.693147f, 1.0f / 120, 1.0f / 24, 1.0f / 6, 1.0f / 2, 1.0f};
|
||||
src = MSMAX(-88.0f, MSMIN(88.0f, src));
|
||||
|
|
|
@ -105,6 +105,8 @@ static inline float16x4_t ms_vcvt_f16_f32(float32x4_t in) {
|
|||
#define MS_SUBQ_F16 vsubq_f16
|
||||
#define MS_MULQ_F16 vmulq_f16
|
||||
#define MS_FMAQ_F16 vfmaq_f16
|
||||
#define MS_MULQ_N_F16(vector, scalar) vmulq_n_f16(vector, scalar)
|
||||
// Lane-wise "greater than" on float16x8_t operands; must map to the f16
// intrinsic (yields a uint16x8_t mask), not the f32 one.
#define MS_CMPGTQ_F16(src1, src2) vcgtq_f16(src1, src2)
|
||||
|
||||
static inline float16x8_t MS_TANHX8_F16(float16x8_t src) {
|
||||
float32x4_t src_low = MS_CVT_F32_F16(vget_low_f16(src));
|
||||
|
|
|
@ -44,5 +44,6 @@ int AssertCPUKernel::Run() {
|
|||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Assert, LiteKernelCreator<AssertCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Assert, LiteKernelCreator<AssertCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Assert, LiteKernelCreator<AssertCPUKernel>)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -90,6 +90,7 @@ int ConstantOfShapeCPUKernel::Run() {
|
|||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeInt64, PrimitiveType_ConstantOfShape, LiteKernelCreator<ConstantOfShapeCPUKernel>)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -1,50 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/kernel/arm/base/depth_to_space_base.h"
|
||||
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_FORMAT_ERR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::lite::RET_PARAM_INVALID;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int DepthToSpaceBaseCPUKernel::ReSize() {
|
||||
if (in_tensors_.at(0)->format() != mindspore::NHWC) {
|
||||
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
|
||||
return RET_FORMAT_ERR;
|
||||
}
|
||||
if (param_->block_size_ <= 0) {
|
||||
MS_LOG(ERROR) << "Input block_size should > 0!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto shape_size = in_tensors_.at(0)->shape().size();
|
||||
if (shape_size != DIMENSION_4D) {
|
||||
MS_LOG(ERROR) << "Input shape size should be " << DIMENSION_4D;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
int32_t in_strides[DIMENSION_4D];
|
||||
ComputeStrides(const_cast<int *>(in_tensors_.at(0)->shape().data()), in_strides, shape_size);
|
||||
param_->in_stride_dim0_ = in_strides[0];
|
||||
param_->in_stride_dim1_ = in_strides[1];
|
||||
param_->in_stride_dim2_ = in_strides[2];
|
||||
int32_t out_strides[DIMENSION_4D];
|
||||
ComputeStrides(const_cast<int *>(out_tensors_.at(0)->shape().data()), out_strides, shape_size);
|
||||
param_->out_stride_dim0_ = out_strides[0];
|
||||
param_->out_stride_dim1_ = out_strides[1];
|
||||
param_->out_stride_dim2_ = out_strides[2];
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
|
@ -1,44 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEPTH_TO_SPACE_BASE_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEPTH_TO_SPACE_BASE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/inner_kernel.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "include/context.h"
|
||||
#include "nnacl/nnacl_common.h"
|
||||
#include "nnacl/depth_to_space_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class DepthToSpaceBaseCPUKernel : public InnerKernel {
|
||||
public:
|
||||
DepthToSpaceBaseCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||
: InnerKernel(parameter, inputs, outputs, ctx) {
|
||||
param_ = reinterpret_cast<DepthToSpaceParameter *>(op_parameter_);
|
||||
}
|
||||
virtual ~DepthToSpaceBaseCPUKernel() = default;
|
||||
int Init() override { return lite::RET_OK; }
|
||||
int ReSize() override;
|
||||
int Run() override { return lite::RET_OK; }
|
||||
|
||||
protected:
|
||||
DepthToSpaceParameter *param_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEPTH_TO_SPACE_BASE_H_
|
|
@ -41,7 +41,8 @@ int ActivationFp16CPUKernel::Init() {
|
|||
type_ != schema::ActivationType_LEAKY_RELU && type_ != schema::ActivationType_SIGMOID &&
|
||||
type_ != schema::ActivationType_TANH && type_ != schema::ActivationType_HSWISH &&
|
||||
type_ != schema::ActivationType_SWISH && type_ != schema::ActivationType_HARD_TANH &&
|
||||
type_ != schema::ActivationType_GELU && type_ != schema::ActivationType_HSIGMOID) {
|
||||
type_ != schema::ActivationType_GELU && type_ != schema::ActivationType_HSIGMOID &&
|
||||
type_ != schema::ActivationType_ELU) {
|
||||
MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
|
||||
return RET_ERROR;
|
||||
}
|
||||
|
@ -84,6 +85,8 @@ int ActivationFp16CPUKernel::DoActivation(int task_id) {
|
|||
HardTanhFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, min_val_, max_val_);
|
||||
} else if (type_ == schema::ActivationType_GELU) {
|
||||
error_code = GeluFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, true);
|
||||
} else if (type_ == schema::ActivationType_ELU) {
|
||||
error_code = EluFp16(fp16_input_ + stride * task_id, count, fp16_output_ + stride * task_id, alpha_);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Activation fp16 not support type: " << type_;
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/kernel/arm/fp16/depth_to_space_fp16.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_DepthToSpace;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int DepthToSpaceFp16CPUKernel::Init() {
|
||||
CHECK_LESS_RETURN(in_tensors_.size(), 1);
|
||||
CHECK_LESS_RETURN(out_tensors_.size(), 1);
|
||||
param_->data_type_size_ = sizeof(float16_t);
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_DepthToSpace, LiteKernelCreator<DepthToSpaceFp16CPUKernel>)
|
||||
} // namespace mindspore::kernel
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DEPTH_TO_SPACE_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DEPTH_TO_SPACE_FP16_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/runtime/kernel/arm/fp32/depth_to_space_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class DepthToSpaceFp16CPUKernel : public DepthToSpaceCPUKernel {
|
||||
public:
|
||||
DepthToSpaceFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||
: DepthToSpaceCPUKernel(parameter, inputs, outputs, ctx) {}
|
||||
~DepthToSpaceFp16CPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_DEPTH_TO_SPACE_FP16_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -16,6 +16,8 @@
|
|||
#include "src/runtime/kernel/arm/fp32/depth_to_space_fp32.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "nnacl/base/depth_to_space_base.h"
|
||||
#include "nnacl/nnacl_common.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
|
@ -35,13 +37,38 @@ int DepthToSpaceCPUKernel::Init() {
|
|||
return ReSize();
|
||||
}
|
||||
|
||||
int DepthToSpaceCPUKernel::ReSize() { return DepthToSpaceBaseCPUKernel::ReSize(); }
|
||||
int DepthToSpaceCPUKernel::ReSize() {
|
||||
if (in_tensors_[0]->format() != mindspore::NHWC) {
|
||||
MS_LOG(ERROR) << "depth_to_space only support NHWC now!";
|
||||
return RET_FORMAT_ERR;
|
||||
}
|
||||
if (param_->block_size_ <= 0) {
|
||||
MS_LOG(ERROR) << "Input block_size should > 0!";
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
auto shape_size = in_tensors_[0]->shape().size();
|
||||
if (shape_size != DIMENSION_4D) {
|
||||
MS_LOG(ERROR) << "Input shape size should be " << DIMENSION_4D;
|
||||
return RET_PARAM_INVALID;
|
||||
}
|
||||
int32_t in_strides[DIMENSION_4D];
|
||||
ComputeStrides(const_cast<int *>(in_tensors_[0]->shape().data()), in_strides, shape_size);
|
||||
param_->in_stride_dim0_ = in_strides[0];
|
||||
param_->in_stride_dim1_ = in_strides[1];
|
||||
param_->in_stride_dim2_ = in_strides[2];
|
||||
int32_t out_strides[DIMENSION_4D];
|
||||
ComputeStrides(const_cast<int *>(out_tensors_[0]->shape().data()), out_strides, shape_size);
|
||||
param_->out_stride_dim0_ = out_strides[0];
|
||||
param_->out_stride_dim1_ = out_strides[1];
|
||||
param_->out_stride_dim2_ = out_strides[2];
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int DepthToSpaceCPUKernel::Run() {
|
||||
auto input = in_tensors_[0];
|
||||
auto output = out_tensors_[0];
|
||||
const float *input_data = reinterpret_cast<const float *>(input->data_c());
|
||||
float *output_data = reinterpret_cast<float *>(output->data_c());
|
||||
const void *input_data = input->data_c();
|
||||
void *output_data = output->data_c();
|
||||
auto in_shape = input->shape();
|
||||
MS_CHECK_TRUE_MSG(in_shape.size() == DIMENSION_4D, RET_ERROR, "input shape should be 4!");
|
||||
if (input->format() == mindspore::NHWC) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -17,21 +17,25 @@
|
|||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DEPTH_TO_SPACE_H_
|
||||
|
||||
#include <vector>
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/base/depth_to_space_base.h"
|
||||
#include "src/runtime/kernel/arm/base/depth_to_space_base.h"
|
||||
#include "src/inner_kernel.h"
|
||||
#include "nnacl/depth_to_space_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class DepthToSpaceCPUKernel : public DepthToSpaceBaseCPUKernel {
|
||||
class DepthToSpaceCPUKernel : public InnerKernel {
|
||||
public:
|
||||
DepthToSpaceCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||
: DepthToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {}
|
||||
: InnerKernel(parameter, inputs, outputs, ctx) {
|
||||
param_ = reinterpret_cast<DepthToSpaceParameter *>(op_parameter_);
|
||||
}
|
||||
~DepthToSpaceCPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
protected:
|
||||
DepthToSpaceParameter *param_ = nullptr;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
|
|
|
@ -56,16 +56,13 @@ int InvertPermutationCPUKernel::Run() {
|
|||
}
|
||||
auto input_ptr = reinterpret_cast<int32_t *>(in_tensor->data_c());
|
||||
auto output_ptr = reinterpret_cast<int32_t *>(out_tensor->data_c());
|
||||
CHECK_NULL_RETURN(out_tensor->data_c());
|
||||
CHECK_NULL_RETURN(in_tensor->data_c());
|
||||
if (input_ptr == nullptr || output_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "null pointer dereferencing.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
CHECK_NULL_RETURN(input_ptr);
|
||||
CHECK_NULL_RETURN(output_ptr);
|
||||
InvertPermutation(input_ptr, output_ptr, in_tensors_[0]->ElementsNum());
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_InvertPermutation, LiteKernelCreator<InvertPermutationCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_InvertPermutation, LiteKernelCreator<InvertPermutationCPUKernel>)
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_InvertPermutation, LiteKernelCreator<InvertPermutationCPUKernel>)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -66,8 +66,6 @@ int DepthToSpaceInt8CPUKernel::Init() {
|
|||
return ReSize();
|
||||
}
|
||||
|
||||
int DepthToSpaceInt8CPUKernel::ReSize() { return DepthToSpaceBaseCPUKernel::ReSize(); }
|
||||
|
||||
int DepthToSpaceInt8CPUKernel::Run() {
|
||||
auto input = in_tensors_[0];
|
||||
auto output = out_tensors_[0];
|
||||
|
|
|
@ -21,18 +21,17 @@
|
|||
#include "nnacl/base/depth_to_space_base.h"
|
||||
#include "nnacl/int8/depth_to_space_int8.h"
|
||||
#include "nnacl/int8/quantize.h"
|
||||
#include "src/runtime/kernel/arm/base/depth_to_space_base.h"
|
||||
#include "src/runtime/kernel/arm/fp32/depth_to_space_fp32.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class DepthToSpaceInt8CPUKernel : public DepthToSpaceBaseCPUKernel {
|
||||
class DepthToSpaceInt8CPUKernel : public DepthToSpaceCPUKernel {
|
||||
public:
|
||||
DepthToSpaceInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
|
||||
: DepthToSpaceBaseCPUKernel(parameter, inputs, outputs, ctx) {}
|
||||
: DepthToSpaceCPUKernel(parameter, inputs, outputs, ctx) {}
|
||||
~DepthToSpaceInt8CPUKernel() override;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
private:
|
||||
|
|
Loading…
Reference in New Issue