forked from mindspore-Ecosystem/mindspore
!8878 [MSLITE] convert fusion module secure check
From: @zhengjun10 Reviewed-by: @hangangqiang,@zhanghaibo5 Signed-off-by: @hangangqiang
This commit is contained in:
commit
1321483749
|
@ -736,7 +736,6 @@ STATUS ChangeOpAxis(schema::MetaGraphT *graph, const std::unique_ptr<schema::CNo
|
|||
MS_LOG(ERROR) << "node->primitive->value.AsCrop() is nullptr";
|
||||
return RET_NULL_PTR;
|
||||
}
|
||||
node->primitive->value.AsCrop()->axis = axis_map[origin_axis];
|
||||
// nchw->nhwc,offsets need pad 0;
|
||||
if (axis_map[origin_axis] == 0) {
|
||||
offsets = {offsets[0], offsets[2], offsets[3], offsets[1]};
|
||||
|
|
|
@ -57,34 +57,23 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
return nullptr;
|
||||
}
|
||||
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||
auto pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
|
||||
auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
|
||||
auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
|
||||
auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
|
||||
|
||||
// fusion const_fold
|
||||
auto cf_pm = std::make_shared<opt::PassManager>("constant folding pass manager", false);
|
||||
cf_pm->AddPass(std::make_shared<opt::ConstFoldPass>());
|
||||
cf_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
|
||||
|
||||
// for now - trainning is not supporting fuse operations
|
||||
if (!config->trainModel) {
|
||||
// remove quantdtype when awaretraining
|
||||
pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
|
||||
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
|
||||
pm->AddPass(std::make_shared<opt::LayerNormFusion>());
|
||||
pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
|
||||
pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
|
||||
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu", schema::PrimitiveType_Activation,
|
||||
schema::ActivationType_RELU));
|
||||
pm->AddPass(std::make_shared<opt::ConvActivationFusion>(true, "conv_relu6", schema::PrimitiveType_Activation,
|
||||
schema::ActivationType_RELU6));
|
||||
pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
|
||||
true, "conv_tuple_relu", schema::PrimitiveType_Activation, schema::ActivationType_RELU));
|
||||
pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>(
|
||||
true, "conv_tuple_relu6", schema::PrimitiveType_Activation, schema::ActivationType_RELU6));
|
||||
pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::RemoveIdentityOpPass>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::LayerNormFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
|
||||
}
|
||||
auto weight_format_hardcode_pass = std::make_shared<opt::WeightFormatHardCodePass>();
|
||||
weight_format_hardcode_pass->SetFmkType(config->fmk);
|
||||
|
@ -108,7 +97,7 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
return nullptr;
|
||||
}
|
||||
remove_unused_cast_pass->SetFmkType(config->fmk);
|
||||
pm->AddPass(remove_unused_cast_pass);
|
||||
fusion_pm->AddPass(remove_unused_cast_pass);
|
||||
}
|
||||
if (config->fmk == lite::converter::FmkType_ONNX) {
|
||||
auto remove_unused_transpose_pass = std::make_shared<opt::RemoveUnusedTransposeOpPass>();
|
||||
|
@ -117,17 +106,22 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
|
|||
return nullptr;
|
||||
}
|
||||
remove_unused_transpose_pass->SetFmkType(config->fmk);
|
||||
pm->AddPass(remove_unused_transpose_pass);
|
||||
fusion_pm->AddPass(remove_unused_transpose_pass);
|
||||
}
|
||||
pm->AddPass(std::make_shared<opt::ConvConvFusion>());
|
||||
auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
|
||||
auto inne_context_ptr = std::make_shared<lite::InnerContext>();
|
||||
inne_context_ptr->Init();
|
||||
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));
|
||||
const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
|
||||
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
|
||||
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
|
||||
if (config->fmk == lite::converter::FmkType_TFLITE) {
|
||||
convert_pm->AddPass(std::make_shared<opt::GroupDepthwiseOpConvertPass>());
|
||||
convert_pm->AddPass(std::make_shared<opt::TfliteInputsOrderExchangePass>());
|
||||
}
|
||||
optimizer->AddPassManager(cf_pm);
|
||||
optimizer->AddPassManager(const_fold_pm);
|
||||
optimizer->AddPassManager(convert_pm);
|
||||
optimizer->AddPassManager(pm);
|
||||
optimizer->AddPassManager(fusion_pm);
|
||||
optimizer->AddPassManager(graph_pm);
|
||||
auto new_graph = optimizer->Optimize(old_graph);
|
||||
if (new_graph == nullptr) {
|
||||
|
|
|
@ -24,9 +24,9 @@ namespace mindspore {
|
|||
namespace lite {
|
||||
namespace converter {
|
||||
Flags::Flags() {
|
||||
AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TFLITE | CAFFE | MINDIR | ONNX", "");
|
||||
AddFlag(&Flags::fmkIn, "fmk", "Input model framework type. TF | TFLITE | CAFFE | MINDIR | ONNX", "");
|
||||
AddFlag(&Flags::modelFile, "modelFile",
|
||||
"Input model file. TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
|
||||
"Input model file. TF: *.pb | TFLITE: *.tflite | CAFFE: *.prototxt | MINDIR: *.mindir | ONNX: *.onnx", "");
|
||||
AddFlag(&Flags::outputFile, "outputFile", "Output model file path. Will add .ms automatically", "");
|
||||
AddFlag(&Flags::weightFile, "weightFile", "Input model weight file. Needed when fmk is CAFFE. CAFFE: *.caffemodel",
|
||||
"");
|
||||
|
|
|
@ -494,6 +494,14 @@ bool IsPoolingNode(const BaseRef &n) {
|
|||
return false;
|
||||
}
|
||||
|
||||
bool IsActivationNode(const BaseRef &n) {
|
||||
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
|
||||
auto type = opt::GetCNodeType(n);
|
||||
return type == schema::PrimitiveType_Activation;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsQuantNode(const BaseRef &n) {
|
||||
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
|
||||
auto type = opt::GetCNodeType(n);
|
||||
|
|
|
@ -65,6 +65,8 @@ bool IsPoolingNode(const BaseRef &n);
|
|||
|
||||
bool IsQuantNode(const BaseRef &n);
|
||||
|
||||
bool IsActivationNode(const BaseRef &n);
|
||||
|
||||
bool CheckIsAllInputsParam(const AnfNodePtr &node);
|
||||
|
||||
size_t GetOutputTensorNum(const AnfNodePtr &node);
|
||||
|
|
|
@ -249,7 +249,7 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
|
|||
<< schema::EnumNamePrimitiveType((schema::PrimitiveType)(lite_primitive->Type()));
|
||||
return nullptr;
|
||||
}
|
||||
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, context, lite_primitive.get());
|
||||
auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, context.get(), lite_primitive.get());
|
||||
if (lite_kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";
|
||||
FreeTensors(&input_tensors, &output_tensors);
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
|
||||
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONSTANT_FOLDING_FUSION_H_
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/tensor.h"
|
||||
#include "src/lite_kernel.h"
|
||||
|
@ -27,15 +29,13 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ConstFoldPass : public PatternProcessPass {
|
||||
public:
|
||||
explicit ConstFoldPass(bool multigraph = true) : PatternProcessPass("constfold_pass", multigraph) {
|
||||
this->context = new lite::InnerContext;
|
||||
this->context->Init();
|
||||
}
|
||||
~ConstFoldPass() override { delete (this->context); }
|
||||
explicit ConstFoldPass(std::shared_ptr<lite::InnerContext> context_ptr = nullptr, bool multigraph = true)
|
||||
: PatternProcessPass("constfold_pass", multigraph), context(std::move(context_ptr)) {}
|
||||
~ConstFoldPass() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
lite::InnerContext *context = nullptr;
|
||||
std::shared_ptr<lite::InnerContext> context;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,17 +29,14 @@ constexpr size_t kActivationInputsLength = 2;
|
|||
}
|
||||
const BaseRef ConvActivationFusion::DefinePattern() const {
|
||||
auto conv_var = std::make_shared<CondVar>(IsConvNode);
|
||||
auto prim = new schema::PrimitiveT();
|
||||
prim->value.type = primitive_type;
|
||||
auto prim_value = std::make_shared<lite::PrimitiveC>(prim);
|
||||
return VectorRef({prim_value, conv_var});
|
||||
auto act_var = std::make_shared<CondVar>(IsActivationNode);
|
||||
return VectorRef({act_var, conv_var});
|
||||
}
|
||||
|
||||
const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
MS_LOG(DEBUG) << "conv activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
return nullptr;
|
||||
|
@ -53,7 +50,8 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
|
|||
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec));
|
||||
auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec);
|
||||
MS_ASSERT(act_primitivec != nullptr);
|
||||
if (act_primitivec->GetType() != activation_type) {
|
||||
if (act_primitivec->GetType() != schema::ActivationType_RELU &&
|
||||
act_primitivec->GetType() != schema::ActivationType_RELU6) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr pre_node = act_node->input(1);
|
||||
|
@ -73,7 +71,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
|
|||
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
|
||||
MS_ASSERT(primc != nullptr);
|
||||
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
|
||||
primc->SetActivationType(activation_type);
|
||||
primc->SetActivationType(act_primitivec->GetType());
|
||||
return pre_node;
|
||||
}
|
||||
} else if (node_type == schema::PrimitiveType_DepthwiseConv2D) {
|
||||
|
@ -81,7 +79,7 @@ const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, c
|
|||
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
|
||||
MS_ASSERT(primc != nullptr);
|
||||
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
|
||||
primc->SetActivationType(activation_type);
|
||||
primc->SetActivationType(act_primitivec->GetType());
|
||||
return pre_node;
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -25,15 +25,11 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ConvActivationFusion : public PatternProcessPass {
|
||||
public:
|
||||
ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion",
|
||||
schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
|
||||
schema::ActivationType activation = schema::ActivationType_LEAKY_RELU)
|
||||
: PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {}
|
||||
explicit ConvActivationFusion(bool multigraph = true, const std::string &name = "conv_activation_fusion")
|
||||
: PatternProcessPass(name, multigraph) {}
|
||||
~ConvActivationFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
schema::PrimitiveType primitive_type;
|
||||
schema::ActivationType activation_type;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -113,8 +113,8 @@ const BaseRef ConvBatchNormFusion::DefinePattern() const {
|
|||
// bias --1
|
||||
// estimated_mean --2
|
||||
// estimated_variance --3
|
||||
const void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
|
||||
float *trans_bias) const {
|
||||
void ConvBatchNormFusion::InitTransParam(const CNodePtr &bn_node, int kernel_num, float *trans_scale,
|
||||
float *trans_bias) const {
|
||||
MS_ASSERT(bn_node != nullptr);
|
||||
MS_ASSERT(trans_bias != nullptr);
|
||||
MS_ASSERT(trans_scale != nullptr);
|
||||
|
|
|
@ -25,7 +25,7 @@ class ConvBatchNormFusion : public ConvTransformFusion {
|
|||
explicit ConvBatchNormFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_batchnorm_fusion") {}
|
||||
~ConvBatchNormFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const void InitTransParam(const CNodePtr &, int, float *, float *) const override;
|
||||
void InitTransParam(const CNodePtr &, int, float *, float *) const override;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_BN_FUSION_H_
|
||||
|
|
|
@ -15,11 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "tools/optimizer/fusion/conv_conv_fusion.h"
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/ops/conv2d.h"
|
||||
#include <memory>
|
||||
#include "schema/inner/model_generated.h"
|
||||
#include "src/ops/conv2d.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
|
@ -128,6 +128,7 @@ STATUS GenNewConvWeight(const ParameterPtr &down_weight_node, const ParameterPtr
|
|||
for (int k = 0; k < cout0; k++) {
|
||||
auto up_weight_offset = k * window_size * cin0 + j;
|
||||
auto down_weight_offset = down_weight_base + k;
|
||||
|
||||
auto new_weight_offset = new_weight_base + j;
|
||||
for (int m = 0; m < window_size; m++) {
|
||||
new_weight_data[new_weight_offset + cin0 * m] +=
|
||||
|
|
|
@ -44,8 +44,8 @@ const BaseRef ConvScaleFusion::DefinePattern() const {
|
|||
auto bias_var = std::make_shared<SeqVar>();
|
||||
return VectorRef({bn_var, conv_var, weight_var, bias_var});
|
||||
}
|
||||
const void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale,
|
||||
float *trans_bias) const {
|
||||
void ConvScaleFusion::InitTransParam(const CNodePtr &scale_node, int kernel_num, float *trans_scale,
|
||||
float *trans_bias) const {
|
||||
MS_ASSERT(scale_node != nullptr);
|
||||
MS_ASSERT(trans_bias != nullptr);
|
||||
MS_ASSERT(trans_scale != nullptr);
|
||||
|
|
|
@ -25,7 +25,7 @@ class ConvScaleFusion : public ConvTransformFusion {
|
|||
explicit ConvScaleFusion(bool multigraph = true) : ConvTransformFusion(multigraph, "conv_scale_fusion") {}
|
||||
~ConvScaleFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const void InitTransParam(const CNodePtr &, int, float *, float *) const override;
|
||||
void InitTransParam(const CNodePtr &, int, float *, float *) const override;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
|
||||
|
|
|
@ -119,8 +119,8 @@ const AnfNodePtr ConvTransformFusion::Process(const FuncGraphPtr &func_graph, co
|
|||
return pre_node;
|
||||
}
|
||||
|
||||
const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
|
||||
float *trans_bias) const {
|
||||
void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, int kernel_nums, float *trans_scale,
|
||||
float *trans_bias) const {
|
||||
if (trans_scale == nullptr) {
|
||||
MS_LOG(ERROR) << "new transScale failed";
|
||||
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
|
||||
|
@ -145,9 +145,8 @@ const void ConvTransformFusion::GenTransParam(const CNodePtr &transform_node, in
|
|||
InitTransParam(transform_node, kernel_nums, trans_scale, trans_bias);
|
||||
}
|
||||
|
||||
const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node,
|
||||
int kernel_num, const float *trans_scale,
|
||||
const float *trans_bias) const {
|
||||
void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, int kernel_num,
|
||||
const float *trans_scale, const float *trans_bias) const {
|
||||
MS_ASSERT(conv_node != nullptr);
|
||||
AnfNodePtr conv_weight_node = nullptr;
|
||||
AnfNodePtr conv_bias_node = nullptr;
|
||||
|
@ -203,8 +202,8 @@ const void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph,
|
|||
conv_node->add_input(bias_node);
|
||||
}
|
||||
}
|
||||
const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size,
|
||||
const float *trans_scale) const {
|
||||
void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kernel_num, int kernel_size,
|
||||
const float *trans_scale) const {
|
||||
MS_ASSERT(weight_data != nullptr);
|
||||
MS_ASSERT(trans_scale != nullptr);
|
||||
auto tmp_weight_data = new (std::nothrow) float[kernel_num * kernel_size];
|
||||
|
@ -237,8 +236,8 @@ const void ConvTransformFusion::CalNewWeightTensor(float *weight_data, int kerne
|
|||
delete[] tmp_weight_data;
|
||||
}
|
||||
|
||||
const void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag,
|
||||
const float *trans_scale, const float *trans_bias) {
|
||||
void ConvTransformFusion::CalNewBiasTensor(float *bias_data, int kernel_num, bool bias_flag, const float *trans_scale,
|
||||
const float *trans_bias) const {
|
||||
MS_ASSERT(bias_data != nullptr);
|
||||
MS_ASSERT(trans_bias != nullptr);
|
||||
MS_ASSERT(trans_scale != nullptr);
|
||||
|
|
|
@ -27,11 +27,11 @@ class ConvTransformFusion : public PatternProcessPass {
|
|||
: PatternProcessPass(name, multigraph) {}
|
||||
~ConvTransformFusion() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
const void GenTransParam(const CNodePtr &, int, float *, float *) const;
|
||||
virtual const void InitTransParam(const CNodePtr &, int, float *, float *) const = 0;
|
||||
const void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
|
||||
const void CalNewWeightTensor(float *, int, int, const float *) const;
|
||||
static const void CalNewBiasTensor(float *, int, bool, const float *, const float *);
|
||||
void GenTransParam(const CNodePtr &, int, float *, float *) const;
|
||||
virtual void InitTransParam(const CNodePtr &, int, float *, float *) const = 0;
|
||||
void GenNewConvTensor(const FuncGraphPtr &, const CNodePtr &, int, const float *, const float *) const;
|
||||
void CalNewWeightTensor(float *, int, int, const float *) const;
|
||||
void CalNewBiasTensor(float *, int, bool, const float *, const float *) const;
|
||||
};
|
||||
} // namespace mindspore::opt
|
||||
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_TRANSFORM_FUSION_H_
|
||||
|
|
|
@ -26,26 +26,27 @@
|
|||
namespace mindspore::opt {
|
||||
namespace {
|
||||
constexpr size_t kActivationInputsLength = 2;
|
||||
bool IsTupleGetItemNode(const BaseRef &n) {
|
||||
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
|
||||
auto type = opt::GetCNodeType(n);
|
||||
return type == schema::PrimitiveType_TupleGetItem;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
const BaseRef ConvTupleActivationFusion::DefinePattern() const {
|
||||
auto conv_var = std::make_shared<CondVar>(IsConvNode);
|
||||
auto tuple_getitem_var = std::make_shared<CondVar>(IsTupleGetItemNode);
|
||||
auto tuple_index = std::make_shared<Var>();
|
||||
auto tuple_prim = new schema::PrimitiveT();
|
||||
tuple_prim->value.type = schema::PrimitiveType_TupleGetItem;
|
||||
auto tuple_value = std::make_shared<lite::PrimitiveC>(tuple_prim);
|
||||
VectorRef tuple_get_item = VectorRef({tuple_value, conv_var, tuple_index});
|
||||
|
||||
auto act_prim = new schema::PrimitiveT();
|
||||
act_prim->value.type = primitive_type;
|
||||
auto act_value = std::make_shared<lite::PrimitiveC>(act_prim);
|
||||
return VectorRef({act_value, tuple_get_item});
|
||||
VectorRef tuple_get_item = VectorRef({tuple_getitem_var, conv_var, tuple_index});
|
||||
auto act_var = std::make_shared<CondVar>(IsActivationNode);
|
||||
return VectorRef({act_var, tuple_get_item});
|
||||
}
|
||||
|
||||
const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_ASSERT(func_graph != nullptr);
|
||||
MS_ASSERT(node != nullptr);
|
||||
MS_LOG(DEBUG) << "conv tuple activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
|
||||
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
|
||||
return nullptr;
|
||||
}
|
||||
|
@ -59,7 +60,8 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
|
|||
MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Activation>>(primitivec));
|
||||
auto act_primitivec = utils::cast<std::shared_ptr<mindspore::lite::Activation>>(primitivec);
|
||||
MS_ASSERT(act_primitivec != nullptr);
|
||||
if (act_primitivec->GetType() != activation_type) {
|
||||
if (act_primitivec->GetType() != schema::ActivationType_RELU &&
|
||||
act_primitivec->GetType() != schema::ActivationType_RELU6) {
|
||||
return nullptr;
|
||||
}
|
||||
AnfNodePtr tuple_node = act_node->input(1);
|
||||
|
@ -82,7 +84,7 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
|
|||
auto primc = utils::cast<std::shared_ptr<mindspore::lite::Conv2D>>(primitive_c);
|
||||
MS_ASSERT(primc != nullptr);
|
||||
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
|
||||
primc->SetActivationType(activation_type);
|
||||
primc->SetActivationType(act_primitivec->GetType());
|
||||
conv_node->set_abstract(act_node->abstract());
|
||||
return conv_node;
|
||||
}
|
||||
|
@ -91,7 +93,7 @@ const AnfNodePtr ConvTupleActivationFusion::Process(const FuncGraphPtr &func_gra
|
|||
auto primc = utils::cast<std::shared_ptr<mindspore::lite::DepthwiseConv2D>>(primitive_c);
|
||||
MS_ASSERT(primc != nullptr);
|
||||
if (primc->GetActivationType() == schema::ActivationType_NO_ACTIVATION) {
|
||||
primc->SetActivationType(activation_type);
|
||||
primc->SetActivationType(act_primitivec->GetType());
|
||||
conv_node->set_abstract(act_node->abstract());
|
||||
return conv_node;
|
||||
}
|
||||
|
|
|
@ -25,15 +25,11 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
class ConvTupleActivationFusion : public PatternProcessPass {
|
||||
public:
|
||||
ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "conv_tuple_activation_fusion",
|
||||
schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
|
||||
schema::ActivationType activation = schema::ActivationType_LEAKY_RELU)
|
||||
: PatternProcessPass(name, multigraph), primitive_type(primitive), activation_type(activation) {}
|
||||
explicit ConvTupleActivationFusion(bool multigraph = true, const std::string &name = "conv_tuple_activation_fusion")
|
||||
: PatternProcessPass(name, multigraph) {}
|
||||
~ConvTupleActivationFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
schema::PrimitiveType primitive_type;
|
||||
schema::ActivationType activation_type;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,13 +24,6 @@
|
|||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
bool IsActivationNode(const BaseRef &n) {
|
||||
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
|
||||
auto type = opt::GetCNodeType(n);
|
||||
return type == schema::PrimitiveType_Activation;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
bool IsMulNode(const BaseRef &n) {
|
||||
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
|
||||
auto type = opt::GetCNodeType(n);
|
||||
|
|
Loading…
Reference in New Issue