!11017 [MS_LITE] fix conv bn pass

From: @YeFeng_24
Reviewed-by: @hangangqiang
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2021-01-07 09:38:05 +08:00 committed by Gitee
commit 15a22f6911
4 changed files with 23 additions and 8 deletions

View File

@ -111,8 +111,12 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
// remove quantdtype when awaretraining
fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
conv_bn_pass->SetFmkType(config->fmk);
fusion_pm->AddPass(conv_bn_pass);
auto conv_scale_pass = std::make_shared<opt::ConvScaleFusion>();
conv_scale_pass->SetFmkType(config->fmk);
fusion_pm->AddPass(conv_scale_pass);
fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>());
fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());

View File

@ -36,7 +36,7 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::BatchNormT>();
auto attr = std::make_unique<schema::FusedBatchNormT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
@ -45,7 +45,7 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op,
TensorFlowUtils::FindAttrValue(tf_op, "epsilon", &attr_value);
attr->epsilon = attr_value.f();
primitive->value.type = schema::PrimitiveType_BatchNorm;
primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {

View File

@ -202,10 +202,15 @@ void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num,
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
return;
}
for (int i = 0; i < kernel_num; i++) {
for (int j = 0; j < kernel_size; j++) {
tmp_weight_data[i * kernel_size + j] = weight_data[i * kernel_size + j] * trans_scale[i];
if (this->fmk_type_ == lite::converter::FmkType_TF) {
for (int i = 0; i < kernel_num * kernel_size; i++) {
tmp_weight_data[i] = weight_data[i] * trans_scale[i % kernel_num];
}
} else {
for (int i = 0; i < kernel_num; i++) {
for (int j = 0; j < kernel_size; j++) {
tmp_weight_data[i * kernel_size + j] = weight_data[i * kernel_size + j] * trans_scale[i];
}
}
}

View File

@ -19,7 +19,9 @@
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "tools/converter/converter_flags.h"
using mindspore::lite::converter::FmkType;
namespace mindspore::opt {
class ConvTransformFusion : public PatternProcessPass {
public:
@ -32,6 +34,10 @@ class ConvTransformFusion : public PatternProcessPass {
void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
void CalNewWeightTensor(float *, int, int, const float *) const;
void CalNewBiasTensor(float *, int, bool, const float *, const float *) const;
void SetFmkType(FmkType type) { this->fmk_type_ = type; }
private:
FmkType fmk_type_ = lite::converter::FmkType_TF;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_