add param fusion_blacklists

albert-yan 2022-07-05 16:07:27 +08:00
parent bade14590c
commit 7f2f915e55
9 changed files with 76 additions and 44 deletions

View File

@@ -27,7 +27,7 @@ class ConvActFusionInoutTest : public ConvFusionInoutTest {
   ConvActFusionInoutTest() = default;

  protected:
-  void InitPass() override { this->pass_ = std::make_shared<opt::ConvActivationFusion>(); }
+  void InitPass() override { this->pass_ = std::make_shared<opt::ConvActivationFusion>(nullptr); }
   void InitGraph() override {
     this->graph_ = std::make_shared<FuncGraph>();

View File

@@ -186,48 +186,55 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
   // The training model only does the fusion of the inference part
   // remove quantdtype when awaretraining
-  fusion_pm->AddPass(std::make_shared<opt::AddConcatActivationFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::TransposeFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>(param->fmk_type));
-  fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>(param->fmk_type));
-  fusion_pm->AddPass(std::make_shared<opt::GroupNormFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion2>());
-  fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::BatchNormToScaleFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ActivationFusion>());
-  if (param->fullQuantParam.target_device != quant::NVGPU) {
-    fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
-  }
+  std::vector<opt::PassPtr> fusions{std::make_shared<opt::AddConcatActivationFusion>(),
+                                    std::make_shared<opt::SqueezeFusion>(),
+                                    std::make_shared<opt::TransposeFusion>(),
+                                    std::make_shared<opt::ReshapeReshapeFusion>(),
+                                    std::make_shared<opt::ConvBiasaddFusion>(),
+                                    std::make_shared<opt::ConvBatchNormFusion>(param->fmk_type),
+                                    std::make_shared<opt::ConvScaleFusion>(param->fmk_type),
+                                    std::make_shared<opt::GroupNormFusion>(),
+                                    std::make_shared<opt::TfNormFusion>(),
+                                    std::make_shared<opt::OnnxLayerNormFusion>(),
+                                    std::make_shared<opt::OnnxLayerNormFusion2>(),
+                                    std::make_shared<opt::BatchMatMulFusion>(),
+                                    std::make_shared<opt::BatchNormToScaleFusion>(),
+                                    std::make_shared<opt::SigmoidMulFusion>(),
+                                    std::make_shared<opt::ActivationFusion>(),
+                                    std::make_shared<opt::ConvActivationFusion>(param),
+                                    std::make_shared<opt::ConvTupleGetItemFusion>(),
+                                    std::make_shared<opt::ConvTupleActivationFusion>(),
+                                    std::make_shared<opt::TfliteLstmCellFusion>(),
+                                    std::make_shared<opt::TfLstmCellFusion>(),
+                                    std::make_shared<opt::TfBidirectionGruFusion>(),
+                                    std::make_shared<opt::TfGeLUFusion>(),
+                                    std::make_shared<opt::OnnxGeLUFusion>(),
+                                    std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>(),
+                                    std::make_shared<opt::GLUFusion>(),
+                                    std::make_shared<opt::ConstFoldPass>(param->fmk_type, param->train_model),
+                                    std::make_shared<opt::AffineFusion>(),
+                                    std::make_shared<opt::AffineActivationFusion>(),
+                                    std::make_shared<opt::ConvConvFusion>(),
+                                    std::make_shared<opt::ConvPadFusion>(),
+                                    std::make_shared<opt::MatMulAddFusion>(),
+                                    std::make_shared<opt::MatMulMulFusion>(),
+                                    std::make_shared<opt::TransposeMatMulFusion>(),
+                                    std::make_shared<opt::MulAddFusion>(),
+                                    std::make_shared<opt::ScaleActivationFusion>(),
+                                    std::make_shared<opt::ScaleScaleFusion>(),
+                                    std::make_shared<opt::FullConnectedFusion>(),
+                                    std::make_shared<opt::FullconnectedAddFusion>(),
+                                    std::make_shared<opt::TensorDotFusion>(),
+                                    std::make_shared<opt::MatMulActivationFusion>(param)};
+  for (size_t index = 0; index < fusions.size(); index++) {
+    auto pass_ptr = fusions.at(index);
+    auto pass_name = pass_ptr->name();
+    if (param->fusion_blacklists.find(pass_name) != param->fusion_blacklists.end()) {
+      MS_LOG(INFO) << "Disable fusion: " << pass_name;
+      continue;
+    }
+    fusion_pm->AddPass(pass_ptr);
+  }
-  fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::GLUFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(param->fmk_type, param->train_model));
-  fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ConvPadFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::MatMulMulFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::TransposeMatMulFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::MulAddFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ScaleActivationFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::ScaleScaleFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::FullConnectedFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::FullconnectedAddFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::TensorDotFusion>());
-  fusion_pm->AddPass(std::make_shared<opt::MatMulActivationFusion>(param));
   optimizer->AddPassManager(fusion_pm);
   if (optimizer->Optimize(old_graph) == nullptr) {
     MS_LOG(ERROR) << "run op fusion failed.";
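The blacklist matches on each pass's name(), which for these passes defaults to the class name handed to the PatternProcessPass constructor (e.g. "ConvActivationFusion" in the header change at the end of this commit), so blacklist entries are spelled exactly like the class names above. A minimal standalone sketch of the same filtering idiom, using a hypothetical stand-in Pass type instead of the real pass interface:

#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for the real pass interface; only name() matters here.
class Pass {
 public:
  explicit Pass(std::string name) : name_(std::move(name)) {}
  const std::string &name() const { return name_; }

 private:
  std::string name_;
};

int main() {
  std::set<std::string> fusion_blacklists{"ConvActivationFusion"};
  std::vector<std::shared_ptr<Pass>> fusions{std::make_shared<Pass>("SqueezeFusion"),
                                             std::make_shared<Pass>("ConvActivationFusion"),
                                             std::make_shared<Pass>("MatMulAddFusion")};
  for (const auto &pass : fusions) {
    if (fusion_blacklists.find(pass->name()) != fusion_blacklists.end()) {
      std::cout << "Disable fusion: " << pass->name() << "\n";  // skipped, never registered
      continue;
    }
    std::cout << "AddPass: " << pass->name() << "\n";  // would be added to the pass manager
  }
  return 0;
}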

View File

@@ -21,6 +21,7 @@
 #include <vector>
 #include <set>
 #include "backend/common/optimizer/optimizer.h"
+#include "backend/common/optimizer/pass.h"
 #include "schema/inner/model_generated.h"
 #include "tools/common/meta_graph_serializer.h"
 #include "ir/anf.h"

View File

@@ -183,6 +183,7 @@ int ConfigFileParser::ParseRegistryInfoString(const std::map<std::string, std::m
   std::map<std::string, std::string &> parse_map{
     {"plugin_path", registry_info_string_.plugin_path},
     {"disable_fusion", registry_info_string_.disable_fusion},
+    {"fusion_blacklists", registry_info_string_.fusion_blacklists},
   };
   return SetMapData(map, parse_map, kRegistry);
 }
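For context, these registry keys come from the converter config file; assuming kRegistry maps to a [registry] section header (the section name is not shown in this commit), a hypothetical config exercising the new key could look like:

[registry]
# comma-separated pass names to skip during fusion (sample values, hypothetical)
fusion_blacklists=ConvActivationFusion,MatMulActivationFusion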

View File

@@ -60,6 +60,7 @@ struct FullQuantString {
 struct RegistryInfoString {
   std::string plugin_path;
   std::string disable_fusion;
+  std::string fusion_blacklists;
 };

 struct AclOptionCfgString {

View File

@@ -398,6 +398,13 @@ int ConverterImpl::InitExtendedIntegrationInfo(const std::shared_ptr<ConverterPa
       return RET_INPUT_PARAM_INVALID;
     }
   }
+  if (!extended_info.fusion_blacklists.empty()) {
+    std::vector<std::string> fusions = SplitStringToVector(extended_info.fusion_blacklists, ",");
+    for (const auto &fusion : fusions) {
+      param->fusion_blacklists.insert(fusion);
+    }
+  }
   return RET_OK;
 }
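SplitStringToVector tokenizes the comma-separated string, and the std::set then deduplicates repeated names automatically. A minimal standalone sketch of this parse step, with a hand-rolled splitter standing in for the real SplitStringToVector helper:

#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>

// Hand-rolled stand-in for the converter's SplitStringToVector helper.
std::vector<std::string> Split(const std::string &s, char delim) {
  std::vector<std::string> tokens;
  std::stringstream ss(s);
  std::string item;
  while (std::getline(ss, item, delim)) {
    if (!item.empty()) tokens.push_back(item);
  }
  return tokens;
}

int main() {
  std::set<std::string> fusion_blacklists;
  for (const auto &name : Split("ConvActivationFusion,MatMulAddFusion,ConvActivationFusion", ',')) {
    fusion_blacklists.insert(name);  // std::set drops the duplicate entry
  }
  std::cout << fusion_blacklists.size() << "\n";  // prints 2
  return 0;
}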

View File

@@ -19,6 +19,7 @@
 #include <map>
 #include <string>
 #include <vector>
+#include <set>
 #include "include/converter.h"
 #include "tools/converter/quantizer/quant_params.h"
 #include "tools/converter/preprocess/preprocess_param.h"
@@ -56,6 +57,7 @@ struct ConverterPara {
   bool pre_infer = false;
   bool train_model = false;
   bool no_fusion = false;
+  std::set<std::string> fusion_blacklists;
   // inner
   std::vector<std::string> plugins_path;

View File

@@ -22,6 +22,7 @@
 #include "ops/fusion/activation.h"
 #include "ops/op_utils.h"
 #include "tools/optimizer/common/gllo_utils.h"
+#include "tools/converter/quantizer/quant_params.h"

 namespace mindspore::opt {
 const BaseRef ConvActivationFusion::DefinePattern() const {
@@ -34,6 +35,10 @@ const BaseRef ConvActivationFusion::DefinePattern() const {
 const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                const EquivPtr &) const {
+  // NVGPU doesn't support ConvActivationFusion
+  if (param_->fullQuantParam.target_device == lite::quant::NVGPU) {
+    return nullptr;
+  }
   if (func_graph == nullptr || node == nullptr) {
     lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
     return nullptr;

View File

@@ -18,19 +18,27 @@
 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_ACTIVATION_FUSION_H_

 #include <string>
 #include <memory>
 #include "backend/common/optimizer/optimizer.h"
+#include "tools/converter/cxx_api/converter_para.h"

 namespace mindspore {
 namespace opt {
 class ConvActivationFusion : public PatternProcessPass {
  public:
-  explicit ConvActivationFusion(bool multigraph = true, const std::string &name = "ConvActivationFusion")
-      : PatternProcessPass(name, multigraph) {}
+  explicit ConvActivationFusion(const std::shared_ptr<ConverterPara> &param, bool multigraph = true,
+                                const std::string &name = "ConvActivationFusion")
+      : PatternProcessPass(name, multigraph), param_(param) {}
   ~ConvActivationFusion() override = default;

  private:
   const BaseRef DefinePattern() const override;
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  const std::shared_ptr<ConverterPara> param_;
 };
 }  // namespace opt
 }  // namespace mindspore
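With the new signature, callers supply the shared ConverterPara that carries fullQuantParam and fusion_blacklists. A hypothetical construction sketch inside the converter codebase (assuming ConverterPara lives in namespace mindspore, as its unqualified use in the header above suggests; this fragment depends on the converter headers and is not standalone):

// Hypothetical sketch; field and class names are taken from this commit.
auto param = std::make_shared<mindspore::ConverterPara>();
param->fusion_blacklists.insert("ConvActivationFusion");  // consumed by AnfTransform::RunFusionPass
auto conv_act_fusion = std::make_shared<mindspore::opt::ConvActivationFusion>(param);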