!23417 [MS][LITE][CPU]nchw

Merge pull request !23417 from liuzhongkai/code_rede
This commit is contained in:
i-robot 2021-09-15 06:10:59 +00:00 committed by Gitee
commit b8ace0705f
4 changed files with 34 additions and 10 deletions

View File

@ -90,8 +90,13 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
out_func); out_func);
} else { } else {
#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
WinogradOutputNC4HW4Transform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param, WinogradOutputNC4HW4Transform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
out_func); out_func);
#else
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
out_func);
#endif
} }
} }
} }

View File

@ -18,6 +18,7 @@
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "include/errorcode.h" #include "include/errorcode.h"
#include "nnacl/fp32/instance_norm_fp32.h" #include "nnacl/fp32/instance_norm_fp32.h"
#include "nnacl/fp32/pack_fp32.h"
using mindspore::kernel::KERNEL_ARCH; using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
@ -46,14 +47,14 @@ int InstanceNormCPUKernel::ReSize() {
int InstanceNormCPUKernel::DoInstanceNorm(int task_id) { int InstanceNormCPUKernel::DoInstanceNorm(int task_id) {
int ret = 0; int ret = 0;
if (in_tensors_[0]->format() == NC4HW4) { if (in_tensors_[0]->format() == NC4HW4) { // arm64 x86-avx x86-sse x86
#ifdef ENABLE_AVX #ifdef ENABLE_AVX
ret = InstanceNormNC8HW8(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id); ret = InstanceNormNC8HW8(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
#else #else
ret = InstanceNormNC4HW4(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id); ret = InstanceNormNC4HW4(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
#endif #endif
} else { } else {
ret = InstanceNorm(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id); ret = InstanceNorm(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
} }
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "DoInstanceNorm error error_code[" << ret << "]"; MS_LOG(ERROR) << "DoInstanceNorm error error_code[" << ret << "]";
@ -81,12 +82,27 @@ int InstanceNormCPUKernel::Run() {
CHECK_NULL_RETURN(gamma_data_); CHECK_NULL_RETURN(gamma_data_);
CHECK_NULL_RETURN(beta_data_); CHECK_NULL_RETURN(beta_data_);
CHECK_NULL_RETURN(dst_data_); CHECK_NULL_RETURN(dst_data_);
if (in_tensors_[0]->format() == NC4HW4) {
#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
tmp_src_data_ = src_data_;
#else // other platform is not support nc4hw4 and must be pack to nc4hw4
tmp_src_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(in_tensors_[0]->Size()));
CHECK_NULL_RETURN(tmp_src_data_);
PackNHWCToNC4HW4Fp32(src_data_, tmp_src_data_, param_->batch_, param_->inner_size_, param_->channel_);
#endif
} else {
tmp_src_data_ = src_data_;
}
auto ret = ParallelLaunch(this->ms_context_, InstanceNormRun, this, op_parameter_->thread_num_); auto ret = ParallelLaunch(this->ms_context_, InstanceNormRun, this, op_parameter_->thread_num_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "InstanceNormRun error error_code[" << ret << "]"; MS_LOG(ERROR) << "InstanceNormRun error error_code[" << ret << "]";
return ret;
} }
return RET_OK; if (in_tensors_[0]->format() == NC4HW4) {
#if (!defined(ENABLE_AVX) && !defined(ENABLE_ARM64))
FreeTmpBuffer();
#endif
}
return ret;
} }
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_InstanceNorm, LiteKernelCreator<InstanceNormCPUKernel>) REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_InstanceNorm, LiteKernelCreator<InstanceNormCPUKernel>)

View File

@ -36,10 +36,17 @@ class InstanceNormCPUKernel : public InnerKernel {
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
int DoInstanceNorm(int task_id); int DoInstanceNorm(int task_id);
void FreeTmpBuffer() {
if (tmp_src_data_ != nullptr) {
ms_context_->allocator->Free(tmp_src_data_);
tmp_src_data_ = nullptr;
}
}
private: private:
InstanceNormParameter *param_ = nullptr; InstanceNormParameter *param_ = nullptr;
float *src_data_ = nullptr; float *src_data_ = nullptr;
float *tmp_src_data_ = nullptr;
float *dst_data_ = nullptr; float *dst_data_ = nullptr;
float *gamma_data_ = nullptr; float *gamma_data_ = nullptr;
float *beta_data_ = nullptr; float *beta_data_ = nullptr;

View File

@ -145,11 +145,7 @@ bool RuntimePassValid(const InnerContext *context, std::vector<kernel::LiteKerne
} }
} }
#if defined(ENABLE_ARM64)
return true; return true;
#endif
return false;
} }
void Nc4hw4PassAct(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors) { void Nc4hw4PassAct(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors) {