diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index b1396d5b689..39266199203 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -38,6 +38,7 @@ #include "backend/optimizer/ascend/ir_fusion/square_sum_fusion.h" #include "backend/optimizer/ascend/ir_fusion/clip_by_norm_no_div_square_sum_fusion.h" #include "backend/optimizer/ascend/ir_fusion/lamb_update_with_lr_rule_fusion.h" +#include "backend/optimizer/ascend/ir_fusion/prelu_fusion.h" #include "backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.h" #include "backend/optimizer/ascend/ir_fusion/confusion_softmax_grad_rule.h" #include "backend/optimizer/ascend/ir_fusion/lamb_next_mv_rule.h" @@ -165,6 +166,7 @@ void AddAscendIRFusionRulesPass(PassManager *ir_fusion_pm) { ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); } void AddAscendIRFusionPass(PassManager *ir_fusion_pm) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/prelu_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/prelu_fusion.cc new file mode 100644 index 00000000000..1c8ca3d4bbd --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/prelu_fusion.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/ascend/ir_fusion/prelu_fusion.h" + +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { +const BaseRef PReluFusion::DefinePattern() const { + VectorRef x_pattern({prim::kPrimRelu, VectorRef({prim::kPrimNeg, x_})}); + VectorRef mul_pattern({prim::kPrimMul, VectorRef({prim::kPrimNeg, weight_}), x_pattern}); + VectorRef pattern({prim::kPrimAdd, VectorRef({prim::kPrimRelu, x_}), mul_pattern}); + return pattern; +} + +const AnfNodePtr PReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); + + BaseRef &x_gnode = (*equiv)[x_]; + BaseRef &weight_gnode = (*equiv)[weight_]; + + auto x = utils::cast(x_gnode); + auto weight = utils::cast(weight_gnode); + + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(weight); + + auto prim = std::make_shared(kPReluOpName); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = {NewValueNode(prim), x, weight}; + auto fusion_node = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fusion_node); + fusion_node->set_abstract(node->abstract()); + fusion_node->set_scope(node->scope()); + return fusion_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/prelu_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/prelu_fusion.h new file mode 100644 index 00000000000..5dc1c2e641f --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/prelu_fusion.h @@ -0,0 +1,41 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PRELU_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PRELU_FUSION_H_ + +#include +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class PReluFusion : public PatternProcessPass { + public: + explicit PReluFusion(bool multigraph = true) : PatternProcessPass("prelu_fusion", multigraph) { + x_ = std::make_shared(); + weight_ = std::make_shared(); + } + ~PReluFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr x_; + VarPtr weight_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PRELU_FUSION_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 172683a22ca..173642f7162 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -129,6 +129,7 @@ constexpr auto kBNTrainingReduceGradOpName = "BNTrainingReduceGrad"; constexpr auto kSquareSumV1OpName = "SquareSumV1"; constexpr auto kSquareSumV2OpName = "SquareSumV2"; constexpr auto kClipByNormNoDivSumOpName = "ClipByNormNoDivSum"; +constexpr auto kPReluOpName = "PReLU"; constexpr auto kGreaterOpName = "Greater"; constexpr auto kSqrtOpName = "Sqrt"; constexpr auto kRsqrtOpName = "Rsqrt";