From 86ae625dc7aa6a96717b8d547424ea1fa1abcf1c Mon Sep 17 00:00:00 2001 From: zengxianglong Date: Mon, 26 Oct 2020 09:12:18 +0800 Subject: [PATCH] fix bugs in instance_norm and pad operator --- mindspore/lite/src/ops/instance_norm.cc | 9 ++++++ mindspore/lite/src/ops/pad.cc | 3 +- .../ops/populate/instance_norm_populate.cc | 4 +-- .../lite/src/ops/populate/pad_populate.cc | 30 +++++++++---------- mindspore/lite/src/ops/primitive_c.cc | 3 ++ .../lite/src/runtime/kernel/arm/fp32/pad.cc | 17 ++++++++--- mindspore/lite/tools/common/node_util.cc | 3 +- 7 files changed, 44 insertions(+), 25 deletions(-) diff --git a/mindspore/lite/src/ops/instance_norm.cc b/mindspore/lite/src/ops/instance_norm.cc index 7a5d2b90b2b..41398e3543d 100644 --- a/mindspore/lite/src/ops/instance_norm.cc +++ b/mindspore/lite/src/ops/instance_norm.cc @@ -16,6 +16,11 @@ #include "src/ops/instance_norm.h" #include + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE @@ -60,6 +65,10 @@ int InstanceNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu } float InstanceNorm::GetEpsilon() const { return this->primitive_->value_as_InstanceNorm()->epsilon(); } +PrimitiveC *InstanceNormCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry InstanceNormRegistry(schema::PrimitiveType_InstanceNorm, InstanceNormCreator); #endif } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/pad.cc b/mindspore/lite/src/ops/pad.cc index 0a07484221e..86d01f18a98 100644 --- a/mindspore/lite/src/ops/pad.cc +++ b/mindspore/lite/src/ops/pad.cc @@ -87,11 +87,10 @@ int Pad::InferShape(std::vector inputs, std::vector outputs) } std::vector paddings; - if (GetPaddingMode() == static_cast(schema::PaddingMode_CONSTANT)) { + if (inputs.size() == 1) { paddings = GetPaddings(); } else { // mirror pad - MS_ASSERT(inputs.size() == 2); auto paddings_tensor = inputs.at(1); int rank = static_cast(inputs.front()->shape().size()); MS_ASSERT(paddings_tensor->ElementsNum() == 2 * rank); diff --git a/mindspore/lite/src/ops/populate/instance_norm_populate.cc b/mindspore/lite/src/ops/populate/instance_norm_populate.cc index 91f9450c5b4..d333d75ee3b 100644 --- a/mindspore/lite/src/ops/populate/instance_norm_populate.cc +++ b/mindspore/lite/src/ops/populate/instance_norm_populate.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace lite { -OpParameter *PopulateInstanceNorm(const mindspore::lite::PrimitiveC *primitive) { +OpParameter *PopulateInstanceNormParameter(const mindspore::lite::PrimitiveC *primitive) { const auto param = reinterpret_cast(const_cast(primitive)); InstanceNormParameter *instance_norm_param = @@ -37,6 +37,6 @@ OpParameter *PopulateInstanceNorm(const mindspore::lite::PrimitiveC *primitive) return reinterpret_cast(instance_norm_param); } -Registry InstanceNormParameterRegistry(schema::PrimitiveType_L2Norm, PopulateInstanceNorm); +Registry InstanceNormParameterRegistry(schema::PrimitiveType_InstanceNorm, PopulateInstanceNormParameter); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/src/ops/populate/pad_populate.cc b/mindspore/lite/src/ops/populate/pad_populate.cc index bc5b9eb2a49..d9f19f20e20 100644 --- a/mindspore/lite/src/ops/populate/pad_populate.cc +++ b/mindspore/lite/src/ops/populate/pad_populate.cc @@ -32,24 +32,22 @@ OpParameter *PopulatePadParameter(const mindspore::lite::PrimitiveC *primitive) pad_param->op_parameter_.type_ = primitive->Type(); auto pad_node = reinterpret_cast(const_cast(primitive)); pad_param->pad_mode_ = pad_node->GetPaddingMode(); - if (pad_param->pad_mode_ == static_cast(schema::PaddingMode_CONSTANT)) { - pad_param->constant_value_ = pad_node->GetConstantValue(); - auto size = pad_node->GetPaddings().size(); - if (size > MAX_PAD_SIZE) { - MS_LOG(ERROR) << "Invalid padding size: " << size; - free(pad_param); - return nullptr; - } - - for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) { - pad_param->paddings_[i] = 0; - } - for (size_t i = 0; i < size; i++) { - pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i]; - } - pad_param->padding_length = MAX_PAD_SIZE; + pad_param->constant_value_ = pad_node->GetConstantValue(); + auto size = pad_node->GetPaddings().size(); + if (size > MAX_PAD_SIZE) { + MS_LOG(ERROR) << "Invalid padding size: " << size; + free(pad_param); + return nullptr; } + for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) { + pad_param->paddings_[i] = 0; + } + for (size_t i = 0; i < size; i++) { + pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i]; + } + pad_param->padding_length = MAX_PAD_SIZE; + return reinterpret_cast(pad_param); } Registry PadParameterRegistry(schema::PrimitiveType_Pad, PopulatePadParameter); diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 8f04785bd15..d3b4889aa4f 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -143,6 +143,7 @@ #include "src/ops/audio_spectrogram.h" #include "src/ops/mfcc.h" #include "src/ops/identity.h" +#include "src/ops/instance_norm.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -790,6 +791,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new AudioSpectrogram(primitive); case schema::PrimitiveType_Mfcc: return new Mfcc(primitive); + case schema::PrimitiveType_InstanceNorm: + return new InstanceNorm(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc index a8eede6ddf2..f2dac28faaa 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pad.cc @@ -213,11 +213,20 @@ void PadCPUKernel::CalculateStrides() { } int PadCPUKernel::HandleMirrorPad() { - auto ret = CopyPaddingFromInput(); - if (ret != RET_OK) { - return ret; + if (in_tensors_.size() == 1) { + auto input_shape = in_tensors_.at(0)->shape(); + int rank = static_cast(input_shape.size()); + auto ret = ExtendShape(in_, DEFAULT_PAD_NDIMS, input_shape.data(), rank); + if (ret != RET_OK) { + return ret; + } + } else { + auto ret = CopyPaddingFromInput(); + if (ret != RET_OK) { + return ret; + } } - ret = CheckPaddings(pad_param_->paddings_, DEFAULT_PAD_NDIMS, in_, pad_param_->pad_mode_); + auto ret = CheckPaddings(pad_param_->paddings_, DEFAULT_PAD_NDIMS, in_, pad_param_->pad_mode_); if (ret != RET_OK) { return ret; } diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 81c77aee732..f6775e8b09a 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -46,7 +46,8 @@ static const std::vector nhwcOpList = { schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm, schema::PrimitiveType_PReLU, - schema::PrimitiveType_BiasAdd}; + schema::PrimitiveType_BiasAdd, + schema::PrimitiveType_InstanceNorm}; static const std::vector nhwcOpDualInputList = { #ifdef SUPPORT_TRAIN