forked from mindspore-Ecosystem/mindspore
!7733 fix bugs in instance_norm and pad operator
Merge pull request !7733 from XianglongZeng/myms_code
This commit is contained in:
commit
fa963bb37a
|
@ -16,6 +16,11 @@
|
|||
|
||||
#include "src/ops/instance_norm.h"
|
||||
#include <memory>
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
|
@ -60,6 +65,10 @@ int InstanceNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbu
|
|||
}
|
||||
float InstanceNorm::GetEpsilon() const { return this->primitive_->value_as_InstanceNorm()->epsilon(); }
|
||||
|
||||
PrimitiveC *InstanceNormCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<InstanceNorm>(primitive);
|
||||
}
|
||||
Registry InstanceNormRegistry(schema::PrimitiveType_InstanceNorm, InstanceNormCreator);
|
||||
#endif
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -87,11 +87,10 @@ int Pad::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs)
|
|||
}
|
||||
|
||||
std::vector<int> paddings;
|
||||
if (GetPaddingMode() == static_cast<int>(schema::PaddingMode_CONSTANT)) {
|
||||
if (inputs.size() == 1) {
|
||||
paddings = GetPaddings();
|
||||
} else {
|
||||
// mirror pad
|
||||
MS_ASSERT(inputs.size() == 2);
|
||||
auto paddings_tensor = inputs.at(1);
|
||||
int rank = static_cast<int>(inputs.front()->shape().size());
|
||||
MS_ASSERT(paddings_tensor->ElementsNum() == 2 * rank);
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
OpParameter *PopulateInstanceNorm(const mindspore::lite::PrimitiveC *primitive) {
|
||||
OpParameter *PopulateInstanceNormParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
const auto param =
|
||||
reinterpret_cast<mindspore::lite::InstanceNorm *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
InstanceNormParameter *instance_norm_param =
|
||||
|
@ -37,6 +37,6 @@ OpParameter *PopulateInstanceNorm(const mindspore::lite::PrimitiveC *primitive)
|
|||
return reinterpret_cast<OpParameter *>(instance_norm_param);
|
||||
}
|
||||
|
||||
Registry InstanceNormParameterRegistry(schema::PrimitiveType_L2Norm, PopulateInstanceNorm);
|
||||
Registry InstanceNormParameterRegistry(schema::PrimitiveType_InstanceNorm, PopulateInstanceNormParameter);
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,24 +32,22 @@ OpParameter *PopulatePadParameter(const mindspore::lite::PrimitiveC *primitive)
|
|||
pad_param->op_parameter_.type_ = primitive->Type();
|
||||
auto pad_node = reinterpret_cast<mindspore::lite::Pad *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
pad_param->pad_mode_ = pad_node->GetPaddingMode();
|
||||
if (pad_param->pad_mode_ == static_cast<int>(schema::PaddingMode_CONSTANT)) {
|
||||
pad_param->constant_value_ = pad_node->GetConstantValue();
|
||||
auto size = pad_node->GetPaddings().size();
|
||||
if (size > MAX_PAD_SIZE) {
|
||||
MS_LOG(ERROR) << "Invalid padding size: " << size;
|
||||
free(pad_param);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) {
|
||||
pad_param->paddings_[i] = 0;
|
||||
}
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i];
|
||||
}
|
||||
pad_param->padding_length = MAX_PAD_SIZE;
|
||||
pad_param->constant_value_ = pad_node->GetConstantValue();
|
||||
auto size = pad_node->GetPaddings().size();
|
||||
if (size > MAX_PAD_SIZE) {
|
||||
MS_LOG(ERROR) << "Invalid padding size: " << size;
|
||||
free(pad_param);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < MAX_PAD_SIZE - size; ++i) {
|
||||
pad_param->paddings_[i] = 0;
|
||||
}
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
pad_param->paddings_[MAX_PAD_SIZE - size + i] = pad_node->GetPaddings()[i];
|
||||
}
|
||||
pad_param->padding_length = MAX_PAD_SIZE;
|
||||
|
||||
return reinterpret_cast<OpParameter *>(pad_param);
|
||||
}
|
||||
Registry PadParameterRegistry(schema::PrimitiveType_Pad, PopulatePadParameter);
|
||||
|
|
|
@ -143,6 +143,7 @@
|
|||
#include "src/ops/audio_spectrogram.h"
|
||||
#include "src/ops/mfcc.h"
|
||||
#include "src/ops/identity.h"
|
||||
#include "src/ops/instance_norm.h"
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
#include "src/ops/neg_grad.h"
|
||||
|
@ -790,6 +791,8 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
|
|||
return new AudioSpectrogram(primitive);
|
||||
case schema::PrimitiveType_Mfcc:
|
||||
return new Mfcc(primitive);
|
||||
case schema::PrimitiveType_InstanceNorm:
|
||||
return new InstanceNorm(primitive);
|
||||
|
||||
#ifdef SUPPORT_TRAIN
|
||||
case schema::PrimitiveType_ActivationGrad:
|
||||
|
|
|
@ -213,11 +213,20 @@ void PadCPUKernel::CalculateStrides() {
|
|||
}
|
||||
|
||||
int PadCPUKernel::HandleMirrorPad() {
|
||||
auto ret = CopyPaddingFromInput();
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
if (in_tensors_.size() == 1) {
|
||||
auto input_shape = in_tensors_.at(0)->shape();
|
||||
int rank = static_cast<int>(input_shape.size());
|
||||
auto ret = ExtendShape(in_, DEFAULT_PAD_NDIMS, input_shape.data(), rank);
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
} else {
|
||||
auto ret = CopyPaddingFromInput();
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
ret = CheckPaddings(pad_param_->paddings_, DEFAULT_PAD_NDIMS, in_, pad_param_->pad_mode_);
|
||||
auto ret = CheckPaddings(pad_param_->paddings_, DEFAULT_PAD_NDIMS, in_, pad_param_->pad_mode_);
|
||||
if (ret != RET_OK) {
|
||||
return ret;
|
||||
}
|
||||
|
|
|
@ -46,7 +46,8 @@ static const std::vector<schema::PrimitiveType> nhwcOpList = {
|
|||
schema::PrimitiveType_BatchNorm,
|
||||
schema::PrimitiveType_FusedBatchNorm,
|
||||
schema::PrimitiveType_PReLU,
|
||||
schema::PrimitiveType_BiasAdd};
|
||||
schema::PrimitiveType_BiasAdd,
|
||||
schema::PrimitiveType_InstanceNorm};
|
||||
|
||||
static const std::vector<schema::PrimitiveType> nhwcOpDualInputList = {
|
||||
#ifdef SUPPORT_TRAIN
|
||||
|
|
Loading…
Reference in New Issue