!8878 [MSLITE] convert fusion module secure check

From: @zhengjun10
Reviewed-by: @hangangqiang,@zhanghaibo5
Signed-off-by: @hangangqiang
This commit is contained in:
mindspore-ci-bot 2020-11-27 11:18:49 +08:00 committed by Gitee
commit 1321483749
19 changed files with 88 additions and 100 deletions

View File

@ -736,7 +736,6 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo
MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
return RET_NULL_PTR;
}
node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
// nchw->nhwc,offsets need pad 0;
if (axis_map[origin_axis] == 0) {
offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};

View File

@ -57,34 +57,23 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
return nullptr;
}
auto optimizer = std::make_shared<opt::GraphOptimizer>();
auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
// fusion const_fold
auto cf_pm = std::make_shared<opt::PassManager>("constant folding pass manager", false);
cf_pm->AddPass(std::make_shared<opt::ConstFoldPass>());
cf_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
// for now - trainning is not supporting fuse operations
if (!config->trainModel) {
// remove quantdtype when awaretraining
pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
pm->AddPass(std::make_shared<opt::LayerNormFusion>());
pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu", schema::PrimitiveType_Activation,
schema::ActivationType_RELU));
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,
schema::ActivationType_RELU6));
pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU));
pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6));
pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>());
fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
}
auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
weight_format_hardcode_pass->SetFmkType(config->fmk);
@ -108,7 +97,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
return nullptr;
}
remove_unused_cast_pass->SetFmkType(config->fmk);
pm->AddPass(remove_unused_cast_pass);
fusion_pm->AddPass(remove_unused_cast_pass);
}
if (config->fmk == lite::converter::FmkType_ONNX) {
auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>();
@ -117,17 +106,22 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
return nullptr;
}
remove_unused_transpose_pass->SetFmkType(config->fmk);
pm->AddPass(remove_unused_transpose_pass);
fusion_pm->AddPass(remove_unused_transpose_pass);
}
pm->AddPass(std::make_shared<opt::ConvConvFusion>());
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
auto inne_context_ptr = std::make_shared<lite::InnerContext>();
inne_context_ptr->Init();
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));
const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
if (config->fmk == lite::converter::FmkType_TFLITE) {
convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>());
}
optimizer->AddPassManager(cf_pm);
optimizer->AddPassManager(const_fold_pm);
optimizer->AddPassManager(convert_pm);
optimizer->AddPassManager(pm);
optimizer->AddPassManager(fusion_pm);
optimizer->AddPassManager(graph_pm);
auto new_graph = optimizer->Optimize(old_graph);
if (new_graph == nullptr) {

View File

@ -24,9 +24,9 @@ namespace mindspore {
namespace lite {
namespace converter {
Flags::Flags() {
AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MINDIR | ONNX", "");
AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
AddFlag(&Flags::modelFile, "modelFile",
"Input model file. TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
"Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
"");

View File

@ -494,6 +494,14 @@ bool IsPoolingNode(const BaseRef &n) {
return false;
}
bool IsActivationNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Activation;
}
return false;
}
bool IsQuantNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);

View File

@ -65,6 +65,8 @@ bool IsPoolingNode(const BaseRef &n);
bool IsQuantNode(const BaseRef &n);
bool IsActivationNode(const BaseRef &n);
bool CheckIsAllInputsParam(const AnfNodePtr &node);
size_t GetOutputTensorNum(const AnfNodePtr &node);

View File

@ -249,7 +249,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
return nullptr;
}
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, context, lite_primitive.get());
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, context.get(), lite_primitive.get());
if (lite_kernel == nullptr) {
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
FreeTensors(&input_tensors, &output_tensors);

View File

@ -17,6 +17,8 @@
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
#include <utility>
#include <memory>
#include "schema/inner/model_generated.h"
#include "src/tensor.h"
#include "src/lite_kernel.h"
@ -27,15 +29,13 @@ namespace mindspore {
namespace opt {
class ConstFoldPass : public PatternProcessPass {
public:
explicit ConstFoldPass(bool multigraph = true) : PatternProcessPass("constfold_pass", multigraph) {
this->context = new lite::InnerContext;
this->context->Init();
}
~ConstFoldPass() override { delete (this->context); }
explicit ConstFoldPass(std::shared_ptr<lite::InnerContext> context_ptr = nullptr, bool multigraph = true)
: PatternProcessPass("constfold_pass", multigraph), context(std::move(context_ptr)) {}
~ConstFoldPass() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
lite::InnerContext *context = nullptr;
std::shared_ptr<lite::InnerContext> context;
};
} // namespace opt
} // namespace mindspore

View File

@ -29,17 +29,14 @@ constexpr size_t kActivationInputsLength = 2;
}
const BaseRef ConvActivationFusion::DefinePattern() const {
auto conv_var = std::make_shared<CondVar>(IsConvNode);
auto prim = new schema::PrimitiveT();
prim->value.type = primitive_type;
auto prim_value = std::make_shared<lite::PrimitiveC>(prim);
return VectorRef({prim_value, conv_var});
auto act_var = std::make_shared<CondVar>(IsActivationNode);
return VectorRef({act_var, conv_var});
}
const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
MS_LOG(DEBUG) << "conv activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
return nullptr;
@ -53,7 +50,8 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec));
auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec);
MS_ASSERT(act_primitivec != nullptr);
if (act_primitivec->GetType() != activation_type) {
if (act_primitivec->GetType() != schema::ActivationType_RELU &&
act_primitivec->GetType() != schema::ActivationType_RELU6) {
return nullptr;
}
AnfNodePtr pre_node = act_node->input(1);
@ -73,7 +71,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
primc->SetActivationType(activation_type);
primc->SetActivationType(act_primitivec->GetType());
return pre_node;
}
} else if (node_type == schema::PrimitiveType_DepthwiseConv2D) {
@ -81,7 +79,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
primc->SetActivationType(activation_type);
primc->SetActivationType(act_primitivec->GetType());
return pre_node;
}
} else {

View File

@ -25,15 +25,11 @@ namespace mindspore {
namespace opt {
class ConvActivationFusion : public PatternProcessPass {
public:
ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion",
schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
schema::ActivationType activation = schema::ActivationType_LEAKY_RELU)
: PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {}
explicit ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion")
: PatternProcessPass(name, multigraph) {}
~ConvActivationFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
schema::PrimitiveType primitive_type;
schema::ActivationType activation_type;
};
} // namespace opt
} // namespace mindspore

View File

@ -113,8 +113,8 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
// bias --1
// estimated_mean --2
// estimated_variance --3
const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
float *trans_bias) const {
void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
float *trans_bias) const {
MS_ASSERT(bn_node != nullptr);
MS_ASSERT(trans_bias != nullptr);
MS_ASSERT(trans_scale != nullptr);

View File

@ -25,7 +25,7 @@ class ConvBatchNormFusion : public ConvTransformFusion {
explicit ConvBatchNormFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_batchnorm_fusion") {}
~ConvBatchNormFusion() override = default;
const BaseRef DefinePattern() const override;
const void InitTransParam(const CNodePtr &, int, float *, float *) const override;
void InitTransParam(const CNodePtr &, int, float *, float *) const override;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_

View File

@ -15,11 +15,11 @@
*/
#include "tools/optimizer/fusion/conv_conv_fusion.h"
#include <memory>
#include <functional>
#include "src/ops/primitive_c.h"
#include "src/ops/conv2d.h"
#include <memory>
#include "schema/inner/model_generated.h"
#include "src/ops/conv2d.h"
#include "src/ops/primitive_c.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
@ -128,6 +128,7 @@ STATUS GenNewConvWeight(const ParameterPtr &down_weight_node, const ParameterPtr
for (int k = 0; k < cout0; k++) {
auto up_weight_offset = k * window_size * cin0 + j;
auto down_weight_offset = down_weight_base + k;
auto new_weight_offset = new_weight_base + j;
for (int m = 0; m < window_size; m++) {
new_weight_data[new_weight_offset + cin0 * m] +=

View File

@ -44,8 +44,8 @@ const BaseRef ConvScaleFusion::DefinePattern() const {
auto bias_var = std::make_shared<SeqVar>();
return VectorRef({bn_var, conv_var, weight_var, bias_var});
}
const void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale,
float *trans_bias) const {
void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale,
float *trans_bias) const {
MS_ASSERT(scale_node != nullptr);
MS_ASSERT(trans_bias != nullptr);
MS_ASSERT(trans_scale != nullptr);

View File

@ -25,7 +25,7 @@ class ConvScaleFusion : public ConvTransformFusion {
explicit ConvScaleFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_scale_fusion") {}
~ConvScaleFusion() override = default;
const BaseRef DefinePattern() const override;
const void InitTransParam(const CNodePtr &, int, float *, float *) const override;
void InitTransParam(const CNodePtr &, int, float *, float *) const override;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_

View File

@ -119,8 +119,8 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
return pre_node;
}
const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
float *trans_bias) const {
void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
float *trans_bias) const {
if (trans_scale == nullptr) {
MS_LOG(ERROR) << "new transScale failed";
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@ -145,9 +145,8 @@ const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, in
InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
}
const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node,
int kernel_num, const float *trans_scale,
const float *trans_bias) const {
void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, int kernel_num,
const float *trans_scale, const float *trans_bias) const {
MS_ASSERT(conv_node != nullptr);
AnfNodePtr conv_weight_node = nullptr;
AnfNodePtr conv_bias_node = nullptr;
@ -203,8 +202,8 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph,
conv_node->add_input(bias_node);
}
}
const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size,
const float *trans_scale) const {
void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size,
const float *trans_scale) const {
MS_ASSERT(weight_data != nullptr);
MS_ASSERT(trans_scale != nullptr);
auto tmp_weight_data = new (std::nothrow) float[kernel_num * kernel_size];
@ -237,8 +236,8 @@ const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kerne
delete[] tmp_weight_data;
}
const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag,
const float *trans_scale, const float *trans_bias) {
void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, const float *trans_scale,
const float *trans_bias) const {
MS_ASSERT(bias_data != nullptr);
MS_ASSERT(trans_bias != nullptr);
MS_ASSERT(trans_scale != nullptr);

View File

@ -27,11 +27,11 @@ class ConvTransformFusion : public PatternProcessPass {
: PatternProcessPass(name, multigraph) {}
~ConvTransformFusion() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
const void GenTransParam(const CNodePtr &, int, float *, float *) const;
virtual const void InitTransParam(const CNodePtr &, int, float *, float *) const = 0;
const void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
const void CalNewWeightTensor(float *, int, int, const float *) const;
static const void CalNewBiasTensor(float *, int, bool, const float *, const float *);
void GenTransParam(const CNodePtr &, int, float *, float *) const;
virtual void InitTransParam(const CNodePtr &, int, float *, float *) const = 0;
void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
void CalNewWeightTensor(float *, int, int, const float *) const;
void CalNewBiasTensor(float *, int, bool, const float *, const float *) const;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_

View File

@ -26,26 +26,27 @@
namespace mindspore::opt {
namespace {
constexpr size_t kActivationInputsLength = 2;
bool IsTupleGetItemNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_TupleGetItem;
}
return false;
}
} // namespace
const BaseRef ConvTupleActivationFusion::DefinePattern() const {
auto conv_var = std::make_shared<CondVar>(IsConvNode);
auto tuple_getitem_var = std::make_shared<CondVar>(IsTupleGetItemNode);
auto tuple_index = std::make_shared<Var>();
auto tuple_prim = new schema::PrimitiveT();
tuple_prim->value.type = schema::PrimitiveType_TupleGetItem;
auto tuple_value = std::make_shared<lite::PrimitiveC>(tuple_prim);
VectorRef tuple_get_item = VectorRef({tuple_value, conv_var, tuple_index});
auto act_prim = new schema::PrimitiveT();
act_prim->value.type = primitive_type;
auto act_value = std::make_shared<lite::PrimitiveC>(act_prim);
return VectorRef({act_value, tuple_get_item});
VectorRef tuple_get_item = VectorRef({tuple_getitem_var, conv_var, tuple_index});
auto act_var = std::make_shared<CondVar>(IsActivationNode);
return VectorRef({act_var, tuple_get_item});
}
const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(node != nullptr);
MS_LOG(DEBUG) << "conv tuple activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
return nullptr;
}
@ -59,7 +60,8 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec));
auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec);
MS_ASSERT(act_primitivec != nullptr);
if (act_primitivec->GetType() != activation_type) {
if (act_primitivec->GetType() != schema::ActivationType_RELU &&
act_primitivec->GetType() != schema::ActivationType_RELU6) {
return nullptr;
}
AnfNodePtr tuple_node = act_node->input(1);
@ -82,7 +84,7 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
primc->SetActivationType(activation_type);
primc->SetActivationType(act_primitivec->GetType());
conv_node->set_abstract(act_node->abstract());
return conv_node;
}
@ -91,7 +93,7 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
MS_ASSERT(primc != nullptr);
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
primc->SetActivationType(activation_type);
primc->SetActivationType(act_primitivec->GetType());
conv_node->set_abstract(act_node->abstract());
return conv_node;
}

View File

@ -25,15 +25,11 @@ namespace mindspore {
namespace opt {
class ConvTupleActivationFusion : public PatternProcessPass {
public:
ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "conv_tuple_activation_fusion",
schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
schema::ActivationType activation = schema::ActivationType_LEAKY_RELU)
: PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {}
explicit ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "conv_tuple_activation_fusion")
: PatternProcessPass(name, multigraph) {}
~ConvTupleActivationFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
schema::PrimitiveType primitive_type;
schema::ActivationType activation_type;
};
} // namespace opt
} // namespace mindspore

View File

@ -24,13 +24,6 @@
namespace mindspore::opt {
namespace {
bool IsActivationNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_Activation;
}
return false;
}
bool IsMulNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);