This commit is contained in:
lzk 2021-09-25 20:01:03 -07:00
parent 24ebb966e9
commit 7bf3faa1c2
9 changed files with 144 additions and 6 deletions

View File

@@ -164,6 +164,10 @@ void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int
void PackNHWCToNC8HW8NotAlignedFp16(const float16_t *src, float16_t *dst, const int batch, const int plane,
const int channel) {
if (channel <= C8NUM) {
memcpy(dst, src, batch * plane * channel * sizeof(float16_t));
return;
}
int tmp = DOWN_DIV(channel, C8NUM);
int c_res = channel - tmp * C8NUM;
int c8_block = tmp * plane * C8NUM;

View File

@@ -29,6 +29,68 @@ void PackHWCToWHC(const float *src, float *dst, int height, int width, int chann
}
}
void PackNHWCToNC4HW4NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) {
  // With at most one C4 group per pixel the NHWC and not-aligned NC4HW4
  // layouts coincide byte-for-byte, so a single bulk copy is enough.
  if (channel <= C4NUM) {
    memcpy(dst, src, batch * plane * channel * sizeof(float));
    return;
  }
  const int full_blocks = DOWN_DIV(channel, C4NUM);   // number of complete C4 groups
  const int aligned_channel = full_blocks * C4NUM;    // channels covered by full groups
  const int tail = channel - aligned_channel;         // leftover channels (< C4NUM)
  const int tail_base = full_blocks * plane * C4NUM;  // start of the packed tail region
  for (int n = 0; n < batch; ++n) {
    const int batch_offset = n * plane * channel;
    for (int p = 0; p < plane; ++p) {
      const float *src_pixel = src + batch_offset + p * channel;
      float *dst_pixel = dst + batch_offset + p * C4NUM;
      int c = 0;
      // Copy each full C4 group into its block-interleaved slot.
      for (; c <= channel - C4NUM; c += C4NUM) {
#if defined(ENABLE_SSE) || defined(ENABLE_ARM)
        MS_FLOAT32X4 group = MS_LDQ_F32(src_pixel + c);
        MS_STQ_F32(dst_pixel + c * plane, group);
#else
        for (int i = 0; i < C4NUM; ++i) {
          (dst_pixel + c * plane)[i] = (src_pixel + c)[i];
        }
#endif
      }
      // Remaining channels are stored contiguously after all full groups
      // (plane-major, tail channels per pixel).
      for (; c < channel; ++c) {
        dst[batch_offset + tail_base + p * tail + c - aligned_channel] = src_pixel[c];
      }
    }
  }
}
void PackNHWCToNC8HW8NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel) {
  // With at most one C8 group per pixel the NHWC and not-aligned NC8HW8
  // layouts coincide byte-for-byte, so a single bulk copy is enough.
  if (channel <= C8NUM) {
    memcpy(dst, src, batch * plane * channel * sizeof(float));
    return;
  }
  const int full_blocks = DOWN_DIV(channel, C8NUM);   // number of complete C8 groups
  const int aligned_channel = full_blocks * C8NUM;    // channels covered by full groups
  const int tail = channel - aligned_channel;         // leftover channels (< C8NUM)
  const int tail_base = full_blocks * plane * C8NUM;  // start of the packed tail region
  for (int n = 0; n < batch; ++n) {
    const int batch_offset = n * plane * channel;
    for (int p = 0; p < plane; ++p) {
      const float *src_pixel = src + batch_offset + p * channel;
      float *dst_pixel = dst + batch_offset + p * C8NUM;
      int c = 0;
      // Copy each full C8 group into its block-interleaved slot.
      for (; c <= channel - C8NUM; c += C8NUM) {
#ifdef ENABLE_AVX
        MS_FLOAT32X8 group = MS_LD256_F32(src_pixel + c);
        MS_ST256_F32(dst_pixel + c * plane, group);
#else
        for (int i = 0; i < C8NUM; ++i) {
          (dst_pixel + c * plane)[i] = (src_pixel + c)[i];
        }
#endif
      }
      // Remaining channels are stored contiguously after all full groups
      // (plane-major, tail channels per pixel).
      for (; c < channel; ++c) {
        dst[batch_offset + tail_base + p * tail + c - aligned_channel] = src_pixel[c];
      }
    }
  }
}
void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num,
int block_index) {
// input format : nhwc

View File

@@ -41,6 +41,8 @@ void PackNC8HW8ToNHWCFp32(const void *src, void *dst, int batch, int plane, int
void PackNC8HW8AlignedToNC8HW8NotAlignedFp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel);
void PackNHWCToCXHWNXFp32(const float *src, float *dst, int batch, int plane, int channel);
void PackNHWCToNC4HW4NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel);
void PackNHWCToNC8HW8NotAlignedFp32(const float *src, float *dst, const int batch, const int plane, const int channel);
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
void PackDepthwiseIndirectWeightC4Fp32(const void *src, void *dst, int height, int width, int channel);

View File

@@ -458,7 +458,6 @@ REG_INFER(Erf, PrimType_Erf, CommonInferShape)
REG_INFER(Exp, PrimType_ExpFusion, CommonInferShape)
REG_INFER(FakeQuantWithMinMaxVars, PrimType_FakeQuantWithMinMaxVars, CommonInferShape)
REG_INFER(Floor, PrimType_Floor, CommonInferShape)
REG_INFER(InstanceNorm, PrimType_InstanceNorm, CommonInferShape)
REG_INFER(IsFinite, PrimType_IsFinite, CommonInferShape)
REG_INFER(LeakyRelu, PrimType_LeakyRelu, CommonInferShape)
REG_INFER(Log, PrimType_Log, CommonInferShape)

View File

@@ -79,8 +79,9 @@ int Conv2dInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC *
}
const TensorC *weight_tensor = inputs[1];
TensorC *out_tensor = outputs[0];
out_tensor->format_ = input_tensor->format_;
if (out_tensor->format_ != Format_NC4HW4) {
out_tensor->format_ = input_tensor->format_;
}
out_tensor->data_type_ = input_tensor->data_type_;
ConvParameter *param = (ConvParameter *)parameter;
if (param->group_ == 0) {

View File

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_NNACL_INSTANCE_NORM_INFER_H
#define MINDSPORE_NNACL_INSTANCE_NORM_INFER_H
#include "nnacl/infer/common_infer.h"
#ifdef __cplusplus
extern "C" {
#endif
// Shape inference for the InstanceNorm primitive: the output mirrors the
// input tensor's data type and shape; an NC4HW4 input format is reported
// as NHWC on the output (see the matching .c implementation).
int InstanceNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                           OpParameter *parameter);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_NNACL_INSTANCE_NORM_INFER_H

View File

@@ -0,0 +1,35 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "nnacl/infer/instance_norm_infer.h"
#include "nnacl/infer/crop_infer.h"
#include "nnacl/infer/infer_register.h"
// Infers the output tensor for InstanceNorm: copies inputs[0]'s data type and
// shape to outputs[0]. An NC4HW4 input format is reported as NHWC, since the
// instance-norm kernel repacks its output out of the blocked layout.
int InstanceNormInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC **outputs, size_t outputs_size,
                           OpParameter *parameter) {
  // Validate array lengths before touching inputs[0]/outputs[0]; the original
  // code dereferenced both without checking the sizes.
  if (inputs_size < 1 || outputs_size < 1) {
    return NNACL_NULL_PTR;
  }
  if (parameter == NULL || inputs[0] == NULL || outputs[0] == NULL) {
    return NNACL_NULL_PTR;
  }
  SetDataTypeFormat(outputs[0], inputs[0]);
  // Downstream consumers see the canonical NHWC format rather than the
  // blocked NC4HW4 layout.
  if (outputs[0]->format_ == Format_NC4HW4) {
    outputs[0]->format_ = Format_NHWC;
  }
  if (!InferFlag(inputs, inputs_size)) {
    return NNACL_INFER_INVALID;
  }
  SetShapeTensor(outputs[0], inputs[0]);
  return NNACL_OK;
}
REG_INFER(InstanceNorm, PrimType_InstanceNorm, InstanceNormInferShape)

View File

@@ -31,7 +31,7 @@ int OutputTensor2TensorC(const std::vector<lite::Tensor *> &tensors, std::vector
return RET_ERROR;
}
tensor_c->data_type_ = kNumberTypeFloat32;
tensor_c->format_ = mindspore::NCHW;
tensor_c->format_ = tensors[i]->format();
tensor_c->data_ = nullptr;
tensor_c->shape_size_ = 0;
tensors_c->push_back(tensor_c);

View File

@@ -88,12 +88,16 @@ int InstanceNormCPUKernel::Run() {
#else // other platform is not support nc4hw4 and must be pack to nc4hw4
tmp_src_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(in_tensors_[0]->Size()));
CHECK_NULL_RETURN(tmp_src_data_);
PackNHWCToNC4HW4Fp32(src_data_, tmp_src_data_, param_->batch_, param_->inner_size_, param_->channel_);
PackNHWCToNC4HW4NotAlignedFp32(src_data_, tmp_src_data_, param_->batch_, param_->inner_size_, param_->channel_);
#endif
} else if (in_tensors_[0]->format() == NHWC) {
tmp_src_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(in_tensors_[0]->Size()));
CHECK_NULL_RETURN(tmp_src_data_);
PackNHWCToNC4HW4Fp32(src_data_, tmp_src_data_, param_->batch_, param_->inner_size_, param_->channel_);
#ifdef ENABLE_AVX
PackNHWCToNC8HW8NotAlignedFp32(src_data_, tmp_src_data_, param_->batch_, param_->inner_size_, param_->channel_);
#else
PackNHWCToNC4HW4NotAlignedFp32(src_data_, tmp_src_data_, param_->batch_, param_->inner_size_, param_->channel_);
#endif
in_tensors_[0]->set_format(NC4HW4);
} else {
tmp_src_data_ = src_data_;