add param fusion_blacklists
This commit is contained in:
parent
bade14590c
commit
7f2f915e55
|
@ -27,7 +27,7 @@ class ConvActFusionInoutTest : public ConvFusionInoutTest {
|
|||
ConvActFusionInoutTest() = default;
|
||||
|
||||
protected:
|
||||
void InitPass() override { this->pass_ = std::make_shared<opt::ConvActivationFusion>(); }
|
||||
void InitPass() override { this->pass_ = std::make_shared<opt::ConvActivationFusion>(nullptr); }
|
||||
|
||||
void InitGraph() override {
|
||||
this->graph_ = std::make_shared<FuncGraph>();
|
||||
|
|
|
@ -186,48 +186,55 @@ int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared
|
|||
|
||||
// The training model only does the fusion of the inference part
|
||||
// remove quantdtype when awaretraining
|
||||
fusion_pm->AddPass(std::make_shared<opt::AddConcatActivationFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TransposeFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>(param->fmk_type));
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>(param->fmk_type));
|
||||
fusion_pm->AddPass(std::make_shared<opt::GroupNormFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion2>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::BatchNormToScaleFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ActivationFusion>());
|
||||
if (param->fullQuantParam.target_device != quant::NVGPU) {
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
|
||||
std::vector<opt::PassPtr> fusions{std::make_shared<opt::AddConcatActivationFusion>(),
|
||||
std::make_shared<opt::SqueezeFusion>(),
|
||||
std::make_shared<opt::TransposeFusion>(),
|
||||
std::make_shared<opt::ReshapeReshapeFusion>(),
|
||||
std::make_shared<opt::ConvBiasaddFusion>(),
|
||||
std::make_shared<opt::ConvBatchNormFusion>(param->fmk_type),
|
||||
std::make_shared<opt::ConvScaleFusion>(param->fmk_type),
|
||||
std::make_shared<opt::GroupNormFusion>(),
|
||||
std::make_shared<opt::TfNormFusion>(),
|
||||
std::make_shared<opt::OnnxLayerNormFusion>(),
|
||||
std::make_shared<opt::OnnxLayerNormFusion2>(),
|
||||
std::make_shared<opt::BatchMatMulFusion>(),
|
||||
std::make_shared<opt::BatchNormToScaleFusion>(),
|
||||
std::make_shared<opt::SigmoidMulFusion>(),
|
||||
std::make_shared<opt::ActivationFusion>(),
|
||||
std::make_shared<opt::ConvActivationFusion>(param),
|
||||
std::make_shared<opt::ConvTupleGetItemFusion>(),
|
||||
std::make_shared<opt::ConvTupleActivationFusion>(),
|
||||
std::make_shared<opt::TfliteLstmCellFusion>(),
|
||||
std::make_shared<opt::TfLstmCellFusion>(),
|
||||
std::make_shared<opt::TfBidirectionGruFusion>(),
|
||||
std::make_shared<opt::TfGeLUFusion>(),
|
||||
std::make_shared<opt::OnnxGeLUFusion>(),
|
||||
std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>(),
|
||||
std::make_shared<opt::GLUFusion>(),
|
||||
std::make_shared<opt::ConstFoldPass>(param->fmk_type, param->train_model),
|
||||
std::make_shared<opt::AffineFusion>(),
|
||||
std::make_shared<opt::AffineActivationFusion>(),
|
||||
std::make_shared<opt::ConvConvFusion>(),
|
||||
std::make_shared<opt::ConvPadFusion>(),
|
||||
std::make_shared<opt::MatMulAddFusion>(),
|
||||
std::make_shared<opt::MatMulMulFusion>(),
|
||||
std::make_shared<opt::TransposeMatMulFusion>(),
|
||||
std::make_shared<opt::MulAddFusion>(),
|
||||
std::make_shared<opt::ScaleActivationFusion>(),
|
||||
std::make_shared<opt::ScaleScaleFusion>(),
|
||||
std::make_shared<opt::FullConnectedFusion>(),
|
||||
std::make_shared<opt::FullconnectedAddFusion>(),
|
||||
std::make_shared<opt::TensorDotFusion>(),
|
||||
std::make_shared<opt::MatMulActivationFusion>(param)};
|
||||
for (size_t index = 0; index < fusions.size(); index++) {
|
||||
auto pass_ptr = fusions.at(index);
|
||||
auto pass_name = pass_ptr->name();
|
||||
if (param->fusion_blacklists.find(pass_name) != param->fusion_blacklists.end()) {
|
||||
MS_LOG(INFO) << "Disable fusion: " << pass_name;
|
||||
continue;
|
||||
}
|
||||
fusion_pm->AddPass(pass_ptr);
|
||||
}
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::GLUFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(param->fmk_type, param->train_model));
|
||||
fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvPadFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::MatMulMulFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TransposeMatMulFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::MulAddFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ScaleActivationFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ScaleScaleFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::FullConnectedFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::FullconnectedAddFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::TensorDotFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::MatMulActivationFusion>(param));
|
||||
optimizer->AddPassManager(fusion_pm);
|
||||
if (optimizer->Optimize(old_graph) == nullptr) {
|
||||
MS_LOG(ERROR) << "run op fusion failed.";
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
#include <set>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pass.h"
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "tools/common/meta_graph_serializer.h"
|
||||
#include "ir/anf.h"
|
||||
|
|
|
@ -183,6 +183,7 @@ int ConfigFileParser::ParseRegistryInfoString(const std::map<std::string, std::m
|
|||
std::map<std::string, std::string &> parse_map{
|
||||
{"plugin_path", registry_info_string_.plugin_path},
|
||||
{"disable_fusion", registry_info_string_.disable_fusion},
|
||||
{"fusion_blacklists", registry_info_string_.fusion_blacklists},
|
||||
};
|
||||
return SetMapData(map, parse_map, kRegistry);
|
||||
}
|
||||
|
|
|
@ -60,6 +60,7 @@ struct FullQuantString {
|
|||
struct RegistryInfoString {
|
||||
std::string plugin_path;
|
||||
std::string disable_fusion;
|
||||
std::string fusion_blacklists;
|
||||
};
|
||||
|
||||
struct AclOptionCfgString {
|
||||
|
|
|
@ -398,6 +398,13 @@ int ConverterImpl::InitExtendedIntegrationInfo(const std::shared_ptr<ConverterPa
|
|||
return RET_INPUT_PARAM_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
if (!extended_info.fusion_blacklists.empty()) {
|
||||
std::vector<std::string> fusions = SplitStringToVector(extended_info.fusion_blacklists, ",");
|
||||
for (const auto &fusion : fusions) {
|
||||
param->fusion_blacklists.insert(fusion);
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include "include/converter.h"
|
||||
#include "tools/converter/quantizer/quant_params.h"
|
||||
#include "tools/converter/preprocess/preprocess_param.h"
|
||||
|
@ -56,6 +57,7 @@ struct ConverterPara {
|
|||
bool pre_infer = false;
|
||||
bool train_model = false;
|
||||
bool no_fusion = false;
|
||||
std::set<std::string> fusion_blacklists;
|
||||
|
||||
// inner
|
||||
std::vector<std::string> plugins_path;
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "ops/fusion/activation.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
#include "tools/converter/quantizer/quant_params.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
const BaseRef ConvActivationFusion::DefinePattern() const {
|
||||
|
@ -34,6 +35,10 @@ const BaseRef ConvActivationFusion::DefinePattern() const {
|
|||
|
||||
const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
// NVGPU don't support ConvActivationFusion
|
||||
if (param_->fullQuantParam.target_device == lite::quant::NVGPU) {
|
||||
return nullptr;
|
||||
}
|
||||
if (func_graph == nullptr || node == nullptr) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
|
|
|
@ -18,19 +18,27 @@
|
|||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_CONV_ACTIVATION_FUSION_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "tools/converter/cxx_api/converter_para.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ConvActivationFusion : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConvActivationFusion(bool multigraph = true, const std::string &name = "ConvActivationFusion")
|
||||
: PatternProcessPass(name, multigraph) {}
|
||||
explicit ConvActivationFusion(const std::shared_ptr<ConverterPara> ¶m, bool multigraph = true,
|
||||
const std::string &name = "ConvActivationFusion")
|
||||
: PatternProcessPass(name, multigraph), param_(param) {}
|
||||
|
||||
~ConvActivationFusion() override = default;
|
||||
|
||||
private:
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
const std::shared_ptr<ConverterPara> param_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue