forked from mindspore-Ecosystem/mindspore
!11017 [MS_LITE] fix conv bn pass
From: @YeFeng_24 Reviewed-by: @hangangqiang Signed-off-by: @hangangqiang
This commit is contained in:
commit
15a22f6911
|
@ -111,8 +111,12 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
|
||||||
// remove quantdtype when awaretraining
|
// remove quantdtype when awaretraining
|
||||||
fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
|
fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
|
auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
|
||||||
fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
|
conv_bn_pass->SetFmkType(config->fmk);
|
||||||
|
fusion_pm->AddPass(conv_bn_pass);
|
||||||
|
auto conv_scale_pass = std::make_shared<opt::ConvScaleFusion>();
|
||||||
|
conv_scale_pass->SetFmkType(config->fmk);
|
||||||
|
fusion_pm->AddPass(conv_scale_pass);
|
||||||
fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
|
||||||
fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
|
fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
|
||||||
|
|
|
@ -36,7 +36,7 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
MS_LOG(ERROR) << "primitive is nullptr";
|
MS_LOG(ERROR) << "primitive is nullptr";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
}
|
}
|
||||||
auto attr = std::make_unique<schema::BatchNormT>();
|
auto attr = std::make_unique<schema::FusedBatchNormT>();
|
||||||
if (attr == nullptr) {
|
if (attr == nullptr) {
|
||||||
MS_LOG(ERROR) << "new op failed";
|
MS_LOG(ERROR) << "new op failed";
|
||||||
return RET_NULL_PTR;
|
return RET_NULL_PTR;
|
||||||
|
@ -45,7 +45,7 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op,
|
||||||
TensorFlowUtils::FindAttrValue(tf_op, "epsilon", &attr_value);
|
TensorFlowUtils::FindAttrValue(tf_op, "epsilon", &attr_value);
|
||||||
attr->epsilon = attr_value.f();
|
attr->epsilon = attr_value.f();
|
||||||
|
|
||||||
primitive->value.type = schema::PrimitiveType_BatchNorm;
|
primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
|
||||||
primitive->value.value = attr.release();
|
primitive->value.value = attr.release();
|
||||||
*primitiveC = PrimitiveC::Create(primitive.release());
|
*primitiveC = PrimitiveC::Create(primitive.release());
|
||||||
if (*primitiveC == nullptr) {
|
if (*primitiveC == nullptr) {
|
||||||
|
|
|
@ -202,10 +202,15 @@ void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num,
|
||||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (this->fmk_type_ == lite::converter::FmkType_TF) {
|
||||||
for (int i = 0; i < kernel_num; i++) {
|
for (int i = 0; i < kernel_num * kernel_size; i++) {
|
||||||
for (int j = 0; j < kernel_size; j++) {
|
tmp_weight_data[i] = weight_data[i] * trans_scale[i % kernel_num];
|
||||||
tmp_weight_data[i * kernel_size + j] = weight_data[i * kernel_size + j] * trans_scale[i];
|
}
|
||||||
|
} else {
|
||||||
|
for (int i = 0; i < kernel_num; i++) {
|
||||||
|
for (int j = 0; j < kernel_size; j++) {
|
||||||
|
tmp_weight_data[i * kernel_size + j] = weight_data[i * kernel_size + j] * trans_scale[i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,9 @@
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include "backend/optimizer/common/optimizer.h"
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
#include "tools/converter/converter_flags.h"
|
||||||
|
|
||||||
|
using mindspore::lite::converter::FmkType;
|
||||||
namespace mindspore::opt {
|
namespace mindspore::opt {
|
||||||
class ConvTransformFusion : public PatternProcessPass {
|
class ConvTransformFusion : public PatternProcessPass {
|
||||||
public:
|
public:
|
||||||
|
@ -32,6 +34,10 @@ class ConvTransformFusion : public PatternProcessPass {
|
||||||
void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
|
void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
|
||||||
void CalNewWeightTensor(float *, int, int, const float *) const;
|
void CalNewWeightTensor(float *, int, int, const float *) const;
|
||||||
void CalNewBiasTensor(float *, int, bool, const float *, const float *) const;
|
void CalNewBiasTensor(float *, int, bool, const float *, const float *) const;
|
||||||
|
void SetFmkType(FmkType type) { this->fmk_type_ = type; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
FmkType fmk_type_ = lite::converter::FmkType_TF;
|
||||||
};
|
};
|
||||||
} // namespace mindspore::opt
|
} // namespace mindspore::opt
|
||||||
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
|
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
|
||||||
|
|
Loading…
Reference in New Issue