forked from mindspore-Ecosystem/mindspore
!23417 [MS][LITE][CPU]nchw
Merge pull request !23417 from liuzhongkai/code_rede
This commit is contained in:
commit
b8ace0705f
|
@ -90,8 +90,13 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
|
||||||
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||||
out_func);
|
out_func);
|
||||||
} else {
|
} else {
|
||||||
|
#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
|
||||||
WinogradOutputNC4HW4Transform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
WinogradOutputNC4HW4Transform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||||
out_func);
|
out_func);
|
||||||
|
#else
|
||||||
|
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||||
|
out_func);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "src/kernel_registry.h"
|
#include "src/kernel_registry.h"
|
||||||
#include "include/errorcode.h"
|
#include "include/errorcode.h"
|
||||||
#include "nnacl/fp32/instance_norm_fp32.h"
|
#include "nnacl/fp32/instance_norm_fp32.h"
|
||||||
|
#include "nnacl/fp32/pack_fp32.h"
|
||||||
|
|
||||||
using mindspore::kernel::KERNEL_ARCH;
|
using mindspore::kernel::KERNEL_ARCH;
|
||||||
using mindspore::lite::KernelRegistrar;
|
using mindspore::lite::KernelRegistrar;
|
||||||
|
@ -46,14 +47,14 @@ int InstanceNormCPUKernel::ReSize() {
|
||||||
|
|
||||||
int InstanceNormCPUKernel::DoInstanceNorm(int task_id) {
|
int InstanceNormCPUKernel::DoInstanceNorm(int task_id) {
|
||||||
int ret = 0;
|
int ret = 0;
|
||||||
if (in_tensors_[0]->format() == NC4HW4) {
|
if (in_tensors_[0]->format() == NC4HW4) { // arm64 x86-avx x86-sse x86
|
||||||
#ifdef ENABLE_AVX
|
#ifdef ENABLE_AVX
|
||||||
ret = InstanceNormNC8HW8(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
ret = InstanceNormNC8HW8(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
||||||
#else
|
#else
|
||||||
ret = InstanceNormNC4HW4(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
ret = InstanceNormNC4HW4(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
||||||
#endif
|
#endif
|
||||||
} else {
|
} else {
|
||||||
ret = InstanceNorm(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
ret = InstanceNorm(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
||||||
}
|
}
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "DoInstanceNorm error error_code[" << ret << "]";
|
MS_LOG(ERROR) << "DoInstanceNorm error error_code[" << ret << "]";
|
||||||
|
@ -81,12 +82,27 @@ int InstanceNormCPUKernel::Run() {
|
||||||
CHECK_NULL_RETURN(gamma_data_);
|
CHECK_NULL_RETURN(gamma_data_);
|
||||||
CHECK_NULL_RETURN(beta_data_);
|
CHECK_NULL_RETURN(beta_data_);
|
||||||
CHECK_NULL_RETURN(dst_data_);
|
CHECK_NULL_RETURN(dst_data_);
|
||||||
|
if (in_tensors_[0]->format() == NC4HW4) {
|
||||||
|
#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
|
||||||
|
tmp_src_data_ = src_data_;
|
||||||
|
#else // other platform is not support nc4hw4 and must be pack to nc4hw4
|
||||||
|
tmp_src_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(in_tensors_[0]->Size()));
|
||||||
|
CHECK_NULL_RETURN(tmp_src_data_);
|
||||||
|
PackNHWCToNC4HW4Fp32(src_data_, tmp_src_data_, param_->batch_, param_->inner_size_, param_->channel_);
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
|
tmp_src_data_ = src_data_;
|
||||||
|
}
|
||||||
auto ret = ParallelLaunch(this->ms_context_, InstanceNormRun, this, op_parameter_->thread_num_);
|
auto ret = ParallelLaunch(this->ms_context_, InstanceNormRun, this, op_parameter_->thread_num_);
|
||||||
if (ret != RET_OK) {
|
if (ret != RET_OK) {
|
||||||
MS_LOG(ERROR) << "InstanceNormRun error error_code[" << ret << "]";
|
MS_LOG(ERROR) << "InstanceNormRun error error_code[" << ret << "]";
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
return RET_OK;
|
if (in_tensors_[0]->format() == NC4HW4) {
|
||||||
|
#if (!defined(ENABLE_AVX) && !defined(ENABLE_ARM64))
|
||||||
|
FreeTmpBuffer();
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_InstanceNorm, LiteKernelCreator<InstanceNormCPUKernel>)
|
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_InstanceNorm, LiteKernelCreator<InstanceNormCPUKernel>)
|
||||||
|
|
|
@ -36,10 +36,17 @@ class InstanceNormCPUKernel : public InnerKernel {
|
||||||
int ReSize() override;
|
int ReSize() override;
|
||||||
int Run() override;
|
int Run() override;
|
||||||
int DoInstanceNorm(int task_id);
|
int DoInstanceNorm(int task_id);
|
||||||
|
void FreeTmpBuffer() {
|
||||||
|
if (tmp_src_data_ != nullptr) {
|
||||||
|
ms_context_->allocator->Free(tmp_src_data_);
|
||||||
|
tmp_src_data_ = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
InstanceNormParameter *param_ = nullptr;
|
InstanceNormParameter *param_ = nullptr;
|
||||||
float *src_data_ = nullptr;
|
float *src_data_ = nullptr;
|
||||||
|
float *tmp_src_data_ = nullptr;
|
||||||
float *dst_data_ = nullptr;
|
float *dst_data_ = nullptr;
|
||||||
float *gamma_data_ = nullptr;
|
float *gamma_data_ = nullptr;
|
||||||
float *beta_data_ = nullptr;
|
float *beta_data_ = nullptr;
|
||||||
|
|
|
@ -145,11 +145,7 @@ bool RuntimePassValid(const InnerContext *context, std::vector<kernel::LiteKerne
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#if defined(ENABLE_ARM64)
|
|
||||||
return true;
|
return true;
|
||||||
#endif
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Nc4hw4PassAct(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors) {
|
void Nc4hw4PassAct(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors) {
|
||||||
|
|
Loading…
Reference in New Issue