forked from mindspore-Ecosystem/mindspore
!23417 [MS][LITE][CPU]nchw
Merge pull request !23417 from liuzhongkai/code_rede
This commit is contained in:
commit
b8ace0705f
|
@ -90,8 +90,13 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
|
|||
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
out_func);
|
||||
} else {
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
|
||||
WinogradOutputNC4HW4Transform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
out_func);
|
||||
#else
|
||||
WinogradOutputNHWCTransform(dst_ptr, output_ptr, bias_data, cal_num, out_tile_index, out_w_block, conv_param,
|
||||
out_func);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "src/kernel_registry.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/fp32/instance_norm_fp32.h"
|
||||
#include "nnacl/fp32/pack_fp32.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
|
@ -46,14 +47,14 @@ int InstanceNormCPUKernel::ReSize() {
|
|||
|
||||
int InstanceNormCPUKernel::DoInstanceNorm(int task_id) {
|
||||
int ret = 0;
|
||||
if (in_tensors_[0]->format() == NC4HW4) {
|
||||
if (in_tensors_[0]->format() == NC4HW4) { // arm64 x86-avx x86-sse x86
|
||||
#ifdef ENABLE_AVX
|
||||
ret = InstanceNormNC8HW8(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
||||
ret = InstanceNormNC8HW8(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
||||
#else
|
||||
ret = InstanceNormNC4HW4(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
||||
ret = InstanceNormNC4HW4(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
||||
#endif
|
||||
} else {
|
||||
ret = InstanceNorm(src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
||||
ret = InstanceNorm(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
|
||||
}
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "DoInstanceNorm error error_code[" << ret << "]";
|
||||
|
@ -81,12 +82,27 @@ int InstanceNormCPUKernel::Run() {
|
|||
CHECK_NULL_RETURN(gamma_data_);
|
||||
CHECK_NULL_RETURN(beta_data_);
|
||||
CHECK_NULL_RETURN(dst_data_);
|
||||
if (in_tensors_[0]->format() == NC4HW4) {
|
||||
#if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
|
||||
tmp_src_data_ = src_data_;
|
||||
#else // other platform is not support nc4hw4 and must be pack to nc4hw4
|
||||
tmp_src_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(in_tensors_[0]->Size()));
|
||||
CHECK_NULL_RETURN(tmp_src_data_);
|
||||
PackNHWCToNC4HW4Fp32(src_data_, tmp_src_data_, param_->batch_, param_->inner_size_, param_->channel_);
|
||||
#endif
|
||||
} else {
|
||||
tmp_src_data_ = src_data_;
|
||||
}
|
||||
auto ret = ParallelLaunch(this->ms_context_, InstanceNormRun, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "InstanceNormRun error error_code[" << ret << "]";
|
||||
return ret;
|
||||
}
|
||||
return RET_OK;
|
||||
if (in_tensors_[0]->format() == NC4HW4) {
|
||||
#if (!defined(ENABLE_AVX) && !defined(ENABLE_ARM64))
|
||||
FreeTmpBuffer();
|
||||
#endif
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_InstanceNorm, LiteKernelCreator<InstanceNormCPUKernel>)
|
||||
|
|
|
@ -36,10 +36,17 @@ class InstanceNormCPUKernel : public InnerKernel {
|
|||
int ReSize() override;
|
||||
int Run() override;
|
||||
int DoInstanceNorm(int task_id);
|
||||
void FreeTmpBuffer() {
|
||||
if (tmp_src_data_ != nullptr) {
|
||||
ms_context_->allocator->Free(tmp_src_data_);
|
||||
tmp_src_data_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
InstanceNormParameter *param_ = nullptr;
|
||||
float *src_data_ = nullptr;
|
||||
float *tmp_src_data_ = nullptr;
|
||||
float *dst_data_ = nullptr;
|
||||
float *gamma_data_ = nullptr;
|
||||
float *beta_data_ = nullptr;
|
||||
|
|
|
@ -145,11 +145,7 @@ bool RuntimePassValid(const InnerContext *context, std::vector<kernel::LiteKerne
|
|||
}
|
||||
}
|
||||
|
||||
#if defined(ENABLE_ARM64)
|
||||
return true;
|
||||
#endif
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void Nc4hw4PassAct(std::vector<kernel::LiteKernel *> *kernels, std::vector<Tensor *> *tensors) {
|
||||
|
|
Loading…
Reference in New Issue