From b07d7150cf199b7a628c342076c33f446734acd3 Mon Sep 17 00:00:00 2001
From: yefeng
Date: Wed, 6 Jan 2021 16:39:32 +0800
Subject: [PATCH] 033-conv_bn_fusion_pass-5

---
 mindspore/lite/tools/converter/anf_transform.cc     |  8 ++++++--
 .../converter/parser/tf/tf_batchnorm_parser.cc      |  4 ++--
 .../tools/optimizer/fusion/conv_transform_fusion.cc | 13 +++++++++----
 .../tools/optimizer/fusion/conv_transform_fusion.h  |  6 ++++++
 4 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index 35acda85ccc..b4743fb6992 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -111,8 +111,12 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
     // remove quantdtype when awaretraining
     fusion_pm->AddPass(std::make_shared());
     fusion_pm->AddPass(std::make_shared());
-    fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
-    fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
+    auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
+    conv_bn_pass->SetFmkType(config->fmk);
+    fusion_pm->AddPass(conv_bn_pass);
+    auto conv_scale_pass = std::make_shared<opt::ConvScaleFusion>();
+    conv_scale_pass->SetFmkType(config->fmk);
+    fusion_pm->AddPass(conv_scale_pass);
     fusion_pm->AddPass(std::make_shared());
     fusion_pm->AddPass(std::make_shared());
     fusion_pm->AddPass(std::make_shared());
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_batchnorm_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_batchnorm_parser.cc
index b13f805eacf..a228ff72289 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_batchnorm_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_batchnorm_parser.cc
@@ -36,7 +36,7 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op,
     MS_LOG(ERROR) << "primitive is nullptr";
     return RET_NULL_PTR;
   }
-  auto attr = std::make_unique<schema::BatchNormT>();
+  auto attr = std::make_unique<schema::FusedBatchNormT>();
   if (attr == nullptr) {
     MS_LOG(ERROR) << "new op failed";
     return RET_NULL_PTR;
@@ -45,7 +45,7 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op,
 
   TensorFlowUtils::FindAttrValue(tf_op, "epsilon", &attr_value);
   attr->epsilon = attr_value.f();
-  primitive->value.type = schema::PrimitiveType_BatchNorm;
+  primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
   primitive->value.value = attr.release();
   *primitiveC = PrimitiveC::Create(primitive.release());
   if (*primitiveC == nullptr) {
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
index 66af20806c8..1935d7fbe2a 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc
@@ -202,10 +202,15 @@ void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num,
     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
     return;
   }
-
-  for (int i = 0; i < kernel_num; i++) {
-    for (int j = 0; j < kernel_size; j++) {
-      tmp_weight_data[i * kernel_size + j] = weight_data[i * kernel_size + j] * trans_scale[i];
+  if (this->fmk_type_ == lite::converter::FmkType_TF) {
+    for (int i = 0; i < kernel_num * kernel_size; i++) {
+      tmp_weight_data[i] = weight_data[i] * trans_scale[i % kernel_num];
+    }
+  } else {
+    for (int i = 0; i < kernel_num; i++) {
+      for (int j = 0; j < kernel_size; j++) {
+        tmp_weight_data[i * kernel_size + j] = weight_data[i * kernel_size + j] * trans_scale[i];
+      }
     }
   }
 
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h
index 04ba3e5f917..379edf9315c 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.h
@@ -19,7 +19,9 @@
 
 #include <string>
 #include "backend/optimizer/common/optimizer.h"
+#include "tools/converter/converter_flags.h"
 
+using mindspore::lite::converter::FmkType;
 namespace mindspore::opt {
 class ConvTransformFusion : public PatternProcessPass {
  public:
@@ -32,6 +34,10 @@ class ConvTransformFusion : public PatternProcessPass {
   void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
   void CalNewWeightTensor(float *, int, int, const float *) const;
   void CalNewBiasTensor(float *, int, bool, const float *, const float *) const;
+  void SetFmkType(FmkType type) { this->fmk_type_ = type; }
+
+ private:
+  FmkType fmk_type_ = lite::converter::FmkType_TF;
 };
 }  // namespace mindspore::opt
 #endif  // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_