forked from mindspore-Ecosystem/mindspore
!11017 [MS_LITE] fix conv bn pass
From: @YeFeng_24
Reviewed-by: @hangangqiang
Signed-off-by: @hangangqiang
This commit is contained in:
commit 15a22f6911
@@ -111,8 +111,12 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
   // remove quantdtype when awaretraining
   fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
   fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
+  auto conv_bn_pass = std::make_shared<opt::ConvBatchNormFusion>();
+  conv_bn_pass->SetFmkType(config->fmk);
+  fusion_pm->AddPass(conv_bn_pass);
+  auto conv_scale_pass = std::make_shared<opt::ConvScaleFusion>();
+  conv_scale_pass->SetFmkType(config->fmk);
+  fusion_pm->AddPass(conv_scale_pass);
   fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>());
   fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
   fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
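Why this hunk: ConvBatchNormFusion and ConvScaleFusion both derive from ConvTransformFusion, which now needs to know the source framework before it rewrites conv weights (see the CalNewWeightTensor hunk below), so each pass is constructed, configured via SetFmkType(config->fmk), and only then registered. A minimal, self-contained sketch of that configure-then-register pattern; Pass, PassManager, and FmkType here are simplified stand-ins, not the real MindSpore Lite types.

    // Sketch of the configure-then-register pattern used above.
    // "Pass", "PassManager" and "FmkType" are simplified stand-ins.
    #include <memory>
    #include <utility>
    #include <vector>

    enum class FmkType { kTF, kCaffe, kOnnx };

    struct Pass {
      virtual ~Pass() = default;
    };

    struct ConvFusionPass : Pass {
      // The framework type must be known before the pass runs, because it
      // decides how the conv weight buffer is indexed (see the
      // CalNewWeightTensor hunk below).
      void SetFmkType(FmkType type) { fmk_type_ = type; }
      FmkType fmk_type_ = FmkType::kTF;
    };

    struct PassManager {
      void AddPass(std::shared_ptr<Pass> pass) { passes_.push_back(std::move(pass)); }
      std::vector<std::shared_ptr<Pass>> passes_;
    };

    int main() {
      PassManager fusion_pm;
      // Configure first, then hand ownership to the pass manager --
      // the same shape as the conv_bn_pass / conv_scale_pass change above.
      auto conv_fusion = std::make_shared<ConvFusionPass>();
      conv_fusion->SetFmkType(FmkType::kTF);
      fusion_pm.AddPass(conv_fusion);
    }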
@@ -36,7 +36,7 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op,
     MS_LOG(ERROR) << "primitive is nullptr";
     return RET_NULL_PTR;
   }
-  auto attr = std::make_unique<schema::BatchNormT>();
+  auto attr = std::make_unique<schema::FusedBatchNormT>();
   if (attr == nullptr) {
     MS_LOG(ERROR) << "new op failed";
     return RET_NULL_PTR;
@@ -45,7 +45,7 @@ STATUS TFBatchNormParser::Parse(const tensorflow::NodeDef &tf_op,
   TensorFlowUtils::FindAttrValue(tf_op, "epsilon", &attr_value);
   attr->epsilon = attr_value.f();
 
-  primitive->value.type = schema::PrimitiveType_BatchNorm;
+  primitive->value.type = schema::PrimitiveType_FusedBatchNorm;
   primitive->value.value = attr.release();
   *primitiveC = PrimitiveC::Create(primitive.release());
   if (*primitiveC == nullptr) {
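The two parser hunks are halves of one fix: a TensorFlow FusedBatchNorm node carries learned scale (gamma) and offset (beta) inputs in addition to mean and variance, so lowering it to the plain BatchNorm primitive dropped information. The parser now builds a schema::FusedBatchNormT attribute and tags the primitive as PrimitiveType_FusedBatchNorm. A sketch of the per-channel computation the fused form implies, assuming a channel-contiguous buffer; the function and layout are illustrative, not the lite kernel itself.

    // Illustrative per-channel fused batch norm:
    //   y = gamma * (x - mean) / sqrt(var + epsilon) + beta
    // The plain (non-fused) form lacks gamma/beta, which is why the parser
    // change matters.
    #include <cmath>
    #include <cstdio>

    void FusedBatchNorm(const float *x, float *y, int channels, int elems_per_channel,
                        const float *gamma, const float *beta,
                        const float *mean, const float *var, float epsilon) {
      for (int c = 0; c < channels; ++c) {
        const float inv_std = 1.0f / std::sqrt(var[c] + epsilon);
        for (int i = 0; i < elems_per_channel; ++i) {
          const int idx = c * elems_per_channel + i;
          y[idx] = gamma[c] * (x[idx] - mean[c]) * inv_std + beta[c];
        }
      }
    }

    int main() {
      const float x[2] = {1.0f, 3.0f}, gamma[1] = {1.0f}, beta[1] = {0.0f};
      const float mean[1] = {2.0f}, var[1] = {1.0f};
      float y[2];
      FusedBatchNorm(x, y, 1, 2, gamma, beta, mean, var, 0.0f);
      std::printf("%g %g\n", y[0], y[1]);  // prints -1 1
    }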
@@ -202,10 +202,15 @@ void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num,
     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_MEMORY_FAILED);
     return;
   }
 
-  for (int i = 0; i < kernel_num; i++) {
-    for (int j = 0; j < kernel_size; j++) {
-      tmp_weight_data[i * kernel_size + j] = weight_data[i * kernel_size + j] * trans_scale[i];
-    }
-  }
+  if (this->fmk_type_ == lite::converter::FmkType_TF) {
+    for (int i = 0; i < kernel_num * kernel_size; i++) {
+      tmp_weight_data[i] = weight_data[i] * trans_scale[i % kernel_num];
+    }
+  } else {
+    for (int i = 0; i < kernel_num; i++) {
+      for (int j = 0; j < kernel_size; j++) {
+        tmp_weight_data[i * kernel_size + j] = weight_data[i * kernel_size + j] * trans_scale[i];
+      }
+    }
+  }
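This is the heart of the fix. ConvTransformFusion folds a following BatchNorm/Scale into the conv weights by multiplying each weight by its output channel's scale. For the other frameworks the output channel is the outermost dimension of the weight tensor, so channel i owns the contiguous block [i * kernel_size, (i + 1) * kernel_size); TF's filter layout keeps the output channel as the innermost, fastest-varying index (HWIO ordering), so the channel of flat index i is i % kernel_num. Before the fix, the TF path scaled the wrong elements. A standalone sketch of the two indexing schemes, with the layout assumption spelled out; nothing here beyond the arithmetic comes from the diff.

    // kernel_num = number of output channels, kernel_size = weights per channel.
    // Assumption (matching the fix): TF weight buffers are channel-innermost,
    // other frameworks' are channel-outermost.
    #include <cstdio>

    void ScaleChannelInnermost(float *w, int kernel_num, int kernel_size,
                               const float *scale) {  // TF-style (HWIO) layout
      for (int i = 0; i < kernel_num * kernel_size; i++) {
        w[i] *= scale[i % kernel_num];  // output channel varies fastest
      }
    }

    void ScaleChannelOutermost(float *w, int kernel_num, int kernel_size,
                               const float *scale) {  // channel-outermost layout
      for (int i = 0; i < kernel_num; i++) {
        for (int j = 0; j < kernel_size; j++) {
          w[i * kernel_size + j] *= scale[i];  // whole channel i is contiguous
        }
      }
    }

    int main() {
      // Two output channels, three weights each; scale channel 0 by 2, channel 1 by 3.
      const float scale[2] = {2.0f, 3.0f};
      float tf_layout[6] = {1, 1, 1, 1, 1, 1};    // interleaved: c0,c1,c0,c1,c0,c1
      float kchw_layout[6] = {1, 1, 1, 1, 1, 1};  // contiguous:  c0,c0,c0,c1,c1,c1
      ScaleChannelInnermost(tf_layout, 2, 3, scale);
      ScaleChannelOutermost(kchw_layout, 2, 3, scale);
      for (float v : tf_layout) std::printf("%g ", v);    // 2 3 2 3 2 3
      std::printf("\n");
      for (float v : kchw_layout) std::printf("%g ", v);  // 2 2 2 3 3 3
      std::printf("\n");
    }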
@@ -19,7 +19,9 @@
 
 #include <string>
 #include "backend/optimizer/common/optimizer.h"
+#include "tools/converter/converter_flags.h"
 
+using mindspore::lite::converter::FmkType;
 namespace mindspore::opt {
 class ConvTransformFusion : public PatternProcessPass {
  public:
@@ -32,6 +34,10 @@ class ConvTransformFusion : public PatternProcessPass {
   void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
   void CalNewWeightTensor(float *, int, int, const float *) const;
   void CalNewBiasTensor(float *, int, bool, const float *, const float *) const;
+  void SetFmkType(FmkType type) { this->fmk_type_ = type; }
+
+ private:
+  FmkType fmk_type_ = lite::converter::FmkType_TF;
 };
 } // namespace mindspore::opt
 #endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
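The header change puts the framework knob on the shared base class so both conv fusions inherit it, with an inline setter and a default of FmkType_TF. A stripped-down sketch of the shape of this API; the enum values and class names are stand-ins, and only SetFmkType / fmk_type_ mirror the diff.

    // Stand-in names throughout; SetFmkType / fmk_type_ mirror the diff above.
    enum class FmkTypeSketch { kTF, kCaffe, kOnnx, kTflite };

    class ConvTransformFusionSketch {
     public:
      // Inline setter: callers configure the pass before handing it to the
      // pass manager, as the AnfTransform hunk above does.
      void SetFmkType(FmkTypeSketch type) { fmk_type_ = type; }

     private:
      // Default member initializer keeps behavior well-defined even if
      // SetFmkType is never called.
      FmkTypeSketch fmk_type_ = FmkTypeSketch::kTF;
    };

    // Both concrete fusions pick the setter up from the base class:
    class ConvBatchNormFusionSketch : public ConvTransformFusionSketch {};
    class ConvScaleFusionSketch : public ConvTransformFusionSketch {};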