diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc index 1ecf4bbd06d..3f905fedf9f 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.cc @@ -15,43 +15,9 @@ */ #include "pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h" #include "pre_activate/common/helper.h" -#include "utils/utils.h" namespace mindspore { namespace opt { -namespace { -void GetAdd0AndAdd1(const AnfNodePtr &sub0, AnfNodePtr *add0, AnfNodePtr *add1) { - MS_EXCEPTION_IF_NULL(sub0); - MS_EXCEPTION_IF_NULL(add0); - MS_EXCEPTION_IF_NULL(add1); - auto sub0_cnode = sub0->cast(); - MS_EXCEPTION_IF_NULL(sub0_cnode); - CheckCNodeInputSize(sub0_cnode, kSubInputNum); - AnfNodePtr mul4 = sub0_cnode->input(2); - MS_EXCEPTION_IF_NULL(mul4); - auto mul4_cnode = mul4->cast(); - MS_EXCEPTION_IF_NULL(mul4_cnode); - CheckCNodeInputSize(mul4_cnode, kMulInputNum); - AnfNodePtr true_div0 = mul4_cnode->input(2); - MS_EXCEPTION_IF_NULL(true_div0); - auto true_div0_cnode = true_div0->cast(); - MS_EXCEPTION_IF_NULL(true_div0_cnode); - CheckCNodeInputSize(true_div0_cnode, kRealDivInputNum); - *add0 = true_div0_cnode->input(1); - AnfNodePtr add2 = true_div0_cnode->input(2); - MS_EXCEPTION_IF_NULL(add2); - auto add2_cnode = add2->cast(); - MS_EXCEPTION_IF_NULL(add2_cnode); - CheckCNodeInputSize(add2_cnode, kAddInputNum); - AnfNodePtr sqrt0 = add2_cnode->input(1); - MS_EXCEPTION_IF_NULL(sqrt0); - auto sqrt0_cnode = sqrt0->cast(); - MS_EXCEPTION_IF_NULL(sqrt0_cnode); - CheckCNodeInputSize(sqrt0_cnode, kSqrtInputNum); - *add1 = sqrt0_cnode->input(1); -} -} // namespace - AnfNodePtr AdamApplyOneFusion::CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(equiv); @@ -79,10 +45,10 @@ const BaseRef AdamApplyOneFusion::DefinePattern() const { const auto prim_deal_div = std::make_shared(kRealDivOpName); VectorRef mul2 = VectorRef({prim::kPrimMul, mul_x_input_vars_[2], input_vars_[1]}); VectorRef mul3 = VectorRef({prim::kPrimMul, mul_x_input_vars_[3], VectorRef({prim::kPrimSquare, input_vars_[0]})}); - VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({prim::kPrimTensorAdd, mul2, mul3})}); + VectorRef sqrt0 = VectorRef({prim_sqrt, VectorRef({add1_var_, mul2, mul3})}); VectorRef mul1 = VectorRef({prim::kPrimMul, mul_x_input_vars_[1], input_vars_[0]}); VectorRef mul0 = VectorRef({prim::kPrimMul, mul_x_input_vars_[0], input_vars_[2]}); - VectorRef add0 = VectorRef({prim::kPrimTensorAdd, mul0, mul1}); + VectorRef add0 = VectorRef({add0_var_, mul0, mul1}); VectorRef true_div0 = VectorRef({prim_deal_div, add0, VectorRef({prim::kPrimTensorAdd, sqrt0, add2_y_})}); return VectorRef({prim::kPrimSub, input_vars_[3], VectorRef({prim::kPrimMul, input_vars_[4], true_div0})}); } @@ -96,10 +62,17 @@ const AnfNodePtr AdamApplyOneFusion::Process(const FuncGraphPtr &func_graph, con new_node->set_scope(node->scope()); // Set abstract of new node AbstractBasePtrList new_node_abstract_list; - AnfNodePtr add0 = nullptr; - AnfNodePtr add1 = nullptr; - GetAdd0AndAdd1(node, &add0, &add1); + auto iter_add0 = (*equiv).find(add0_var_); + if (iter_add0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."; + } + auto iter_add1 = (*equiv).find(add1_var_); + if (iter_add1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; + } + auto add0 = utils::cast(iter_add0->second); MS_EXCEPTION_IF_NULL(add0); + auto add1 = utils::cast(iter_add1->second); MS_EXCEPTION_IF_NULL(add1); new_node_abstract_list.push_back(add1->abstract()); new_node_abstract_list.push_back(add0->abstract()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h index 6642561b076..77f66414637 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_fusion.h @@ -19,6 +19,7 @@ #include #include #include "pre_activate/common/optimizer.h" +#include "utils/utils.h" namespace mindspore { namespace opt { @@ -35,6 +36,8 @@ class AdamApplyOneFusion : public PatternProcessPass { mul_x_input_vars_.push_back(std::make_shared()); } add2_y_ = std::make_shared(); + add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); } ~AdamApplyOneFusion() override = default; @@ -46,6 +49,8 @@ class AdamApplyOneFusion : public PatternProcessPass { std::vector input_vars_; std::vector mul_x_input_vars_; VarPtr add2_y_; + VarPtr add0_var_; + VarPtr add1_var_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc index 442aa64217a..4a2387d3cc3 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.cc @@ -17,48 +17,13 @@ #include #include -#include #include "session/anf_runtime_algorithm.h" #include "ir/primitive.h" -#include "utils/utils.h" #include "pre_activate/common/helper.h" namespace mindspore { namespace opt { -namespace { -std::tuple GetAdd0Add1Node(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto sub0 = node->cast(); - MS_EXCEPTION_IF_NULL(sub0); - auto mul5_anf = sub0->input(2); - MS_EXCEPTION_IF_NULL(mul5_anf); - auto mul5 = mul5_anf->cast(); - MS_EXCEPTION_IF_NULL(mul5); - auto add3_anf = mul5->input(2); - MS_EXCEPTION_IF_NULL(add3_anf); - auto add3 = add3_anf->cast(); - MS_EXCEPTION_IF_NULL(add3); - auto real_div0_anf = add3->input(1); - MS_EXCEPTION_IF_NULL(real_div0_anf); - auto real_div0 = real_div0_anf->cast(); - MS_EXCEPTION_IF_NULL(real_div0); - auto add0_anf = real_div0->input(1); - MS_EXCEPTION_IF_NULL(add0_anf); - auto add2_anf = real_div0->input(2); - MS_EXCEPTION_IF_NULL(add2_anf); - auto add2 = add2_anf->cast(); - MS_EXCEPTION_IF_NULL(add2); - auto sqrt0_anf = add2->input(1); - MS_EXCEPTION_IF_NULL(sqrt0_anf); - auto sqrt0 = sqrt0_anf->cast(); - MS_EXCEPTION_IF_NULL(sqrt0); - auto add1_anf = sqrt0->input(1); - MS_EXCEPTION_IF_NULL(add1_anf); - return std::make_tuple(add0_anf, add1_anf); -} -} // namespace - std::vector AdamApplyOneWithDecayRule::GetFusionNodeInputs(const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(equiv); auto input0 = utils::cast((*equiv)[input0_]); @@ -82,10 +47,10 @@ const BaseRef AdamApplyOneWithDecayRule::DefinePattern() const { VectorRef mul0_pattern({prim::kPrimMul, mul0_x_, input2_}); VectorRef mul1_pattern({prim::kPrimMul, mul1_x_, input0_}); VectorRef square0_pattern({prim::kPrimSquare, input0_}); - VectorRef add0_pattern({prim::kPrimTensorAdd, mul0_pattern, mul1_pattern}); + VectorRef add0_pattern({add0_var_, mul0_pattern, mul1_pattern}); VectorRef mul2_pattern({prim::kPrimMul, mul2_x_, input1_}); VectorRef mul3_pattern({prim::kPrimMul, mul3_x_, square0_pattern}); - VectorRef add1_pattern({prim::kPrimTensorAdd, mul2_pattern, mul3_pattern}); + VectorRef add1_pattern({add1_var_, mul2_pattern, mul3_pattern}); VectorRef sqrt0_pattern({sqrt, add1_pattern}); VectorRef add2_pattern({prim::kPrimTensorAdd, sqrt0_pattern, add2_y_}); VectorRef mul4_pattern({prim::kPrimMul, mul4_x_, input3_}); @@ -107,9 +72,18 @@ const AnfNodePtr AdamApplyOneWithDecayRule::Process(const FuncGraphPtr &graph, c MS_EXCEPTION_IF_NULL(fusion_node); fusion_node->set_scope(node->scope()); - AnfNodePtr add0 = nullptr; - AnfNodePtr add1 = nullptr; - std::tie(add0, add1) = GetAdd0Add1Node(node); + auto iter_add0 = (*equiv).find(add0_var_); + if (iter_add0 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add0 var after matched."; + } + auto iter_add1 = (*equiv).find(add1_var_); + if (iter_add1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; + } + auto add0 = utils::cast(iter_add0->second); + MS_EXCEPTION_IF_NULL(add0); + auto add1 = utils::cast(iter_add1->second); + MS_EXCEPTION_IF_NULL(add1); auto types = {AnfAlgo::GetOutputInferDataType(add1, 0), AnfAlgo::GetOutputInferDataType(add0, 0), AnfAlgo::GetOutputInferDataType(node, 0)}; auto shapes = {AnfAlgo::GetOutputInferShape(add1, 0), AnfAlgo::GetOutputInferShape(add0, 0), diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h index a6bab48770b..72c54f35352 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/adam_apply_one_with_decay_rule.h @@ -19,6 +19,7 @@ #include #include #include "pre_activate/common/optimizer.h" +#include "utils/utils.h" namespace mindspore { namespace opt { class AdamApplyOneWithDecayRule : public PatternProcessPass { @@ -36,6 +37,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { mul3_x_ = std::make_shared(); mul4_x_ = std::make_shared(); add2_y_ = std::make_shared(); + add0_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); + add1_var_ = std::make_shared(std::make_shared(prim::kPrimTensorAdd->name())); } ~AdamApplyOneWithDecayRule() override = default; const BaseRef DefinePattern() const override; @@ -54,6 +57,8 @@ class AdamApplyOneWithDecayRule : public PatternProcessPass { VarPtr mul3_x_; VarPtr mul4_x_; VarPtr add2_y_; + VarPtr add0_var_; + VarPtr add1_var_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc index ca9c90f4e5b..68baeeed992 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.cc @@ -16,36 +16,9 @@ #include "pre_activate/ascend/ir_fusion/lamb_next_right_rule.h" #include #include "pre_activate/common/helper.h" -#include "utils/utils.h" namespace mindspore { namespace opt { -namespace { -AnfNodePtr GetAdd1Node(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto add2_cnode = node->cast(); - MS_EXCEPTION_IF_NULL(add2_cnode); - if (add2_cnode->inputs().size() != kAddInputNum) { - MS_LOG(ERROR) << "The input size of Add2 is not equal to " << kAddInputNum; - } - AnfNodePtr sqrt0 = add2_cnode->input(1); - MS_EXCEPTION_IF_NULL(sqrt0); - auto sqrt0_cnode = sqrt0->cast(); - MS_EXCEPTION_IF_NULL(sqrt0_cnode); - if (sqrt0_cnode->inputs().size() != kSqrtInputNum) { - MS_LOG(ERROR) << "The input size of Sqrt0 is not equal to " << kSqrtInputNum; - } - AnfNodePtr real_div1 = sqrt0_cnode->input(1); - MS_EXCEPTION_IF_NULL(real_div1); - auto real_div1_cnode = real_div1->cast(); - MS_EXCEPTION_IF_NULL(real_div1_cnode); - if (real_div1_cnode->inputs().size() != kMulInputNum) { - MS_LOG(ERROR) << "The input size of RealDiv1 is not equal to " << kMulInputNum; - } - return real_div1_cnode->input(1); -} -} // namespace - AnfNodePtr LambNextRightRule::CreateLambNextRightNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(equiv); @@ -79,7 +52,7 @@ const BaseRef LambNextRightRule::DefinePattern() const { const auto prim_sqrt = std::make_shared(kSqrtOpName); MS_EXCEPTION_IF_NULL(prim_sqrt); VectorRef mul3 = VectorRef({prim::kPrimMul, mul3_x_, VectorRef({prim::kPrimSquare, input0_})}); - VectorRef add1 = VectorRef({prim::kPrimTensorAdd, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); + VectorRef add1 = VectorRef({add1_var_, VectorRef({prim::kPrimMul, mul2_x_, input1_}), mul3}); return VectorRef( {prim::kPrimTensorAdd, VectorRef({prim_sqrt, VectorRef({prim::kPrimMul, add1, true_div1_recip_})}), add2_y_}); } @@ -91,7 +64,11 @@ const AnfNodePtr LambNextRightRule::Process(const FuncGraphPtr &func_graph, cons auto new_node = CreateLambNextRightNode(func_graph, equiv); MS_EXCEPTION_IF_NULL(new_node); // Set abstract of new node - AnfNodePtr add1 = GetAdd1Node(node); + auto iter_add1 = (*equiv).find(add1_var_); + if (iter_add1 == (*equiv).end()) { + MS_LOG(EXCEPTION) << "The equiv map is expected to contains the add1 var after matched."; + } + auto add1 = utils::cast(iter_add1->second); MS_EXCEPTION_IF_NULL(add1); AbstractBasePtrList new_node_abstract_list; new_node_abstract_list.push_back(add1->abstract()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h index f78be7460bb..3d15001da24 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/lamb_next_right_rule.h @@ -18,6 +18,8 @@ #include #include "pre_activate/common/optimizer.h" +#include "utils/utils.h" + namespace mindspore { namespace opt { class LambNextRightRule : public PatternProcessPass { @@ -29,7 +31,8 @@ class LambNextRightRule : public PatternProcessPass { mul2_x_(std::make_shared()), mul3_x_(std::make_shared()), true_div1_recip_(std::make_shared()), - add2_y_(std::make_shared()) {} + add2_y_(std::make_shared()), + add1_var_(std::make_shared(std::make_shared(prim::kPrimTensorAdd->name()))) {} ~LambNextRightRule() override = default; const BaseRef DefinePattern() const override; @@ -44,6 +47,7 @@ class LambNextRightRule : public PatternProcessPass { VarPtr mul3_x_; VarPtr true_div1_recip_; VarPtr add2_y_; + VarPtr add1_var_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/optimizer.cc b/mindspore/ccsrc/pre_activate/common/optimizer.cc index 62cff76be01..0e74da3fe8a 100644 --- a/mindspore/ccsrc/pre_activate/common/optimizer.cc +++ b/mindspore/ccsrc/pre_activate/common/optimizer.cc @@ -30,7 +30,8 @@ namespace mindspore { namespace opt { namespace { -AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, bool multigraph); +AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph); ValueNodePtr CreateValueNodeWithSexp(const BaseRef &sexp) { if (utils::isa(sexp)) { @@ -71,12 +72,20 @@ VarNodePtr CreateVarNodeWithSexp(const BaseRef &sexp, const BaseRef &graph) { return nullptr; } -AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, bool multigraph = false) { +AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph = false) { MS_LOG(DEBUG) << "SexpToNode sexp: " + sexp.ToString() + ", graph " + graph.ToString(); + MS_EXCEPTION_IF_NULL(primitive_vars); if (utils::isa(sexp)) { - return HandleSexpVector(sexp, graph, multigraph); + return HandleSexpVector(sexp, graph, primitive_vars, multigraph); } if (utils::isa(sexp)) { + auto var_ptr = utils::cast(sexp); + MS_EXCEPTION_IF_NULL(var_ptr); + if (var_ptr->primitive()) { + (*primitive_vars)[var_ptr->primitive()] = var_ptr; + return NewValueNode(var_ptr->primitive()); + } return CreateVarNodeWithSexp(sexp, graph); } if (utils::isa(sexp)) { @@ -89,13 +98,14 @@ AnfNodePtr SexpToNode(const BaseRef &sexp, const BaseRef &graph, bool multigraph return value_node; } -AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, bool multigraph) { +AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, PrimitiveVarMap *primitive_vars, + bool multigraph) { MS_LOG(DEBUG) << "HandleSexpVector sexp: " + sexp.ToString() + ", graph " + graph.ToString(); std::vector input_nodes; const auto &tuple = utils::cast(sexp); if (multigraph && utils::isa(graph)) { for (auto &x : tuple) { - AnfNodePtr node = SexpToNode(x, std::make_shared("G"), true); + AnfNodePtr node = SexpToNode(x, std::make_shared("G"), primitive_vars, true); input_nodes.push_back(node); } VarPtr var_ptr = utils::cast(graph); @@ -103,7 +113,7 @@ AnfNodePtr HandleSexpVector(const BaseRef &sexp, const BaseRef &graph, bool mult } for (auto &x : tuple) { - AnfNodePtr node = SexpToNode(x, graph, multigraph); + AnfNodePtr node = SexpToNode(x, graph, primitive_vars, multigraph); input_nodes.push_back(node); } return CreateCNodeWithGraph(input_nodes, graph); @@ -166,7 +176,8 @@ PatternProcessPass::PatternProcessPass(const std::string &name, bool multigraph) multigraph_(multigraph), pattern_engine_(PatternEngine(std::make_shared(), std::function(AnfEqual), - std::function(CNodeTypeEqual))) {} + std::function(CNodeTypeEqual))), + primitive_vars_(std::make_shared()) {} const BaseRef PatternProcessPass::DefinePattern() const { VarPtr X = std::make_shared(); @@ -176,7 +187,7 @@ const BaseRef PatternProcessPass::DefinePattern() const { void PatternProcessPass::Build() { VarPtr fg = std::make_shared("RootG"); BaseRef pattern = std::move(DefinePattern()); - pattern_ = SexpToNode(pattern, fg, multigraph_); + pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_); } AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { @@ -185,7 +196,8 @@ AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNode } auto empty_equiv = std::make_shared(); - EquivPtr equiv = pattern_engine_.Match(pattern_, node, empty_equiv); + MS_EXCEPTION_IF_NULL(primitive_vars_); + EquivPtr equiv = pattern_engine_.Match(pattern_, node, *primitive_vars_, empty_equiv); if (equiv != nullptr && !equiv->empty()) { return Process(func_graph, node, equiv); } diff --git a/mindspore/ccsrc/pre_activate/common/optimizer.h b/mindspore/ccsrc/pre_activate/common/optimizer.h index 8ef0b6dc340..eade7f77896 100644 --- a/mindspore/ccsrc/pre_activate/common/optimizer.h +++ b/mindspore/ccsrc/pre_activate/common/optimizer.h @@ -19,6 +19,7 @@ #include #include #include +#include #include "ir/anf.h" #include "ir/func_graph.h" @@ -46,6 +47,7 @@ class PatternProcessPass : public NodePass { AnfNodePtr pattern_ = nullptr; bool multigraph_ = true; PatternEngine pattern_engine_; + PrimitiveVarMapPtr primitive_vars_; }; class GraphOptimizer { diff --git a/mindspore/ccsrc/pre_activate/common/pattern_engine.cc b/mindspore/ccsrc/pre_activate/common/pattern_engine.cc index e2ff321a894..350332b9d1b 100644 --- a/mindspore/ccsrc/pre_activate/common/pattern_engine.cc +++ b/mindspore/ccsrc/pre_activate/common/pattern_engine.cc @@ -42,7 +42,7 @@ void Var::EnsureTag() { } } -bool operator==(const VarPtr& lhs, const VarPtr& rhs) { +bool operator==(const VarPtr &lhs, const VarPtr &rhs) { if (lhs->isa() && rhs->isa()) { CondVarPtr v1 = dyn_cast(lhs); CondVarPtr v2 = dyn_cast(rhs); @@ -63,7 +63,7 @@ std::string SeqVar::ToString() const { return buffer.str(); } -std::ostream& operator<<(std::ostream& os, const VarPtr& var) { +std::ostream &operator<<(std::ostream &os, const VarPtr &var) { if (var == nullptr) { os << ""; } else { @@ -73,10 +73,10 @@ std::ostream& operator<<(std::ostream& os, const VarPtr& var) { } template <> -std::ostream& operator<<(std::ostream& os, const Equiv& equiv) { +std::ostream &operator<<(std::ostream &os, const Equiv &equiv) { os << "[Equiv]" << "\n"; - for (auto& equiv_item : equiv) { + for (auto &equiv_item : equiv) { auto k = equiv_item.first; os << k << ":"; BaseRef x = equiv_item.second; @@ -104,7 +104,7 @@ std::ostream& operator<<(std::ostream& os, const Equiv& equiv) return os; } -static BaseRef GetVar(const BaseRef& x) { +static BaseRef GetVar(const BaseRef &x) { MS_LOG(DEBUG) << "getVar start :%s" + x.ToString(); if (utils::isa(x)) { auto node = utils::cast(x); @@ -129,7 +129,7 @@ static BaseRef GetVar(const BaseRef& x) { return x; } -EquivPtr MatchOnVar(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) { +EquivPtr MatchOnVar(const BaseRef &pattern, const BaseRef &expr, EquivPtr equiv) { MS_LOG(DEBUG) << "MatchOnVar pattern " + pattern.ToString() + " expr: " + expr.ToString(); MS_EXCEPTION_IF_NULL(equiv); if (utils::isa(pattern)) { @@ -144,8 +144,8 @@ EquivPtr MatchOnVar(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) return nullptr; } -bool PatternEngine::ToVector(const VectorRef& pattern_ref, const VectorRef& expr_ref, VectorRef* const values_pattern, - VectorRef* const values_expr) const { +bool PatternEngine::ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const { MS_EXCEPTION_IF_NULL(values_expr); if (utils::isa(pattern_ref)) { *values_pattern = pattern_ref; @@ -155,12 +155,12 @@ bool PatternEngine::ToVector(const VectorRef& pattern_ref, const VectorRef& expr return false; } -bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref, VectorRef* const values_pattern, - VectorRef* const values_expr) const { +bool PatternEngine::ToVector(const BaseRef &pattern_ref, const BaseRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const { MS_EXCEPTION_IF_NULL(values_expr); // visitor to visite the list - auto appender_pattern = [](VectorRef& values) { - std::function fn = [&](const BaseRef& u) { + auto appender_pattern = [](VectorRef &values) { + std::function fn = [&](const BaseRef &u) { values.push_back(GetVar(u)); return u; }; @@ -174,8 +174,8 @@ bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref return false; } - auto appender_expr = [](VectorRef& values) { - std::function fn = [&](const BaseRef& u) { + auto appender_expr = [](VectorRef &values) { + std::function fn = [&](const BaseRef &u) { values.push_back(u); return u; }; @@ -187,10 +187,10 @@ bool PatternEngine::ToVector(const BaseRef& pattern_ref, const BaseRef& expr_ref return visitor_->Visit(expr_ref, nullptr); } -static int GetSVarStartIndex(const VectorRef& values) { +static int GetSVarStartIndex(const VectorRef &values) { int index = -1; int count = 0; - for (auto& value : values) { + for (auto &value : values) { if (utils::isa(value) && utils::cast(value)->isa()) { if (index != -1) { MS_LOG(DEBUG) << "Multiple SVars in sequence"; @@ -203,7 +203,35 @@ static int GetSVarStartIndex(const VectorRef& values) { return index; } -EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorRef& values_expr, EquivPtr equiv) const { +void UpdateEquivMap(const VectorRef &values_pattern, const BaseRef &expr_ref, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) { + if (equiv == nullptr || values_pattern.empty() || !utils::isa(values_pattern[0]) || + !utils::isa(expr_ref)) { + return; + } + auto real_node = utils::cast(expr_ref); + MS_EXCEPTION_IF_NULL(real_node); + if (!real_node->isa()) { + return; + } + auto prim_node = utils::cast(values_pattern[0]); + MS_EXCEPTION_IF_NULL(prim_node); + if (!IsValueNode(prim_node)) { + return; + } + ValuePtr value = GetValueNode(prim_node); + MS_EXCEPTION_IF_NULL(value); + auto prim = value->cast(); + MS_EXCEPTION_IF_NULL(prim); + auto iter = primitive_vars.find(prim); + if (iter == primitive_vars.end()) { + return; + } + (*equiv)[iter->second] = real_node; +} + +EquivPtr PatternEngine::AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, + const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const { int svar_index = GetSVarStartIndex(values_pattern); if (svar_index == kInvalidVarIndex) { return nullptr; @@ -229,12 +257,12 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorR if (svar_index != -1 && i == IntToSize(svar_index)) { auto seq = std::vector(values_expr.begin() + svar_index, values_expr.begin() + svar_index + SizeToInt(diff)); - equiv = Match(values_pattern[svar_index], seq, equiv); + equiv = Match(values_pattern[svar_index], seq, primitive_vars, equiv); } else { if (svar_index != -1 && i > IntToSize(svar_index)) { expr_i = i + diff - 1; } - equiv = Match(values_pattern[i], values_expr[expr_i], equiv); + equiv = Match(values_pattern[i], values_expr[expr_i], primitive_vars, equiv); } if (equiv == nullptr) { return nullptr; @@ -243,7 +271,8 @@ EquivPtr PatternEngine::AlignSVar(const VectorRef& values_pattern, const VectorR return equiv; } -EquivPtr PatternEngine::Match(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) const { +EquivPtr PatternEngine::Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) const { MS_LOG(DEBUG) << "-----[in Match]"; MS_LOG(DEBUG) << "GetVar w"; BaseRef pattern_ref = GetVar(pattern); @@ -292,10 +321,12 @@ EquivPtr PatternEngine::Match(const BaseRef& pattern, const BaseRef& expr, Equiv // 6. if any svar in both side, find the SeqVar index, // try to pack the Var s in std::vector to a Seq and match elements one by one. // check svar - return AlignSVar(values_pattern, values_expr, equiv); + equiv = AlignSVar(values_pattern, values_expr, primitive_vars, equiv); + UpdateEquivMap(values_pattern, expr_ref, primitive_vars, equiv); + return equiv; } -BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) const { +BaseRef PatternEngine::Replace(const BaseRef &pattern, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(equiv); MS_LOG(DEBUG) << "-----[in Replace]"; BaseRef ref = GetVar(pattern); @@ -304,7 +335,7 @@ BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) co // w is var if (utils::isa(ref)) { - const VarPtr& var = utils::cast(ref); + const VarPtr &var = utils::cast(ref); auto iter = equiv->find(var); if (iter != equiv->end()) { out = iter->second; @@ -316,7 +347,7 @@ BaseRef PatternEngine::Replace(const BaseRef& pattern, const EquivPtr& equiv) co } // visitor to visit the list - std::function fn = [&, this, equiv](const BaseRef& u) { return Replace(u, equiv); }; + std::function fn = [&, this, equiv](const BaseRef &u) { return Replace(u, equiv); }; visitor_->SetFn(fn); BaseRef visit_out; diff --git a/mindspore/ccsrc/pre_activate/common/pattern_engine.h b/mindspore/ccsrc/pre_activate/common/pattern_engine.h index 432746332ff..858b1aecb88 100644 --- a/mindspore/ccsrc/pre_activate/common/pattern_engine.h +++ b/mindspore/ccsrc/pre_activate/common/pattern_engine.h @@ -31,6 +31,7 @@ #include #include #include +#include #include "pre_activate/common/visit.h" #include "ir/base.h" @@ -44,16 +45,19 @@ using CondVarPtr = std::shared_ptr; using SVarPtr = std::shared_ptr; const int kInvalidVarIndex = -2; -using ConditionFunc = std::function; +using ConditionFunc = std::function; // Base wildcard variable which could match any anf node. class Var : public Base { friend class VarHasher; public: - explicit Var(const std::string& tag = "") : tag_(tag) { EnsureTag(); } - Var(const Var& other) : Base(other), tag_(other.tag_) {} - virtual Var& operator=(const Var& other) { + explicit Var(std::string tag = "") : tag_(std::move(tag)), primitive_(nullptr) { EnsureTag(); } + explicit Var(const PrimitivePtr &primitive, std::string tag = "") : tag_(std::move(tag)), primitive_(primitive) { + EnsureTag(); + } + Var(const Var &other) : Base(other), tag_(other.tag_) {} + virtual Var &operator=(const Var &other) { if (&other == this) { return *this; } @@ -63,12 +67,13 @@ class Var : public Base { ~Var() override = default; MS_DECLARE_PARENT(Var, Base); - virtual bool matches(const BaseRef&) { return true; } + virtual bool matches(const BaseRef &) { return true; } - virtual bool operator==(const Var& other) const { return tag_ == other.tag_; } - bool operator!=(const Var& other) const { return !(&other == this); } + virtual bool operator==(const Var &other) const { return tag_ == other.tag_; } + bool operator!=(const Var &other) const { return !(&other == this); } std::string tag() const { return tag_; } + PrimitivePtr primitive() const { return primitive_; } std::string ToString() const override { std::ostringstream buffer; buffer << "Var(" << tag_ << ")"; @@ -80,12 +85,13 @@ class Var : public Base { void EnsureTag(); std::string tag_; + PrimitivePtr primitive_; }; // VarNode means variable node, a subclass of AnfNode class VarNode : public AnfNode { public: - VarNode(const VarPtr& value, const FuncGraphPtr& func_graph) : AnfNode(func_graph), var_(value) {} + VarNode(const VarPtr &value, const FuncGraphPtr &func_graph) : AnfNode(func_graph), var_(value) {} ~VarNode() override = default; MS_DECLARE_PARENT(VarNode, AnfNode); @@ -95,16 +101,16 @@ using VarNodePtr = std::shared_ptr; class VarHasher { public: - std::size_t operator()(const Var& var) const { return var.hash(); } + std::size_t operator()(const Var &var) const { return var.hash(); } }; // Condition Var, match an anf node when condition function return true. class CondVar : public Var { public: - explicit CondVar(const ConditionFunc& cond) : cond_fn_(cond) {} + explicit CondVar(const ConditionFunc &cond) : cond_fn_(cond) {} ~CondVar() override = default; MS_DECLARE_PARENT(CondVar, Var); - bool matches(const BaseRef& value) override { + bool matches(const BaseRef &value) override { MS_LOG(DEBUG) << "CondVarPtr match: " + value.ToString(); if (utils::isa(value)) { return false; @@ -124,55 +130,60 @@ class SeqVar : public Var { ~SeqVar() override = default; MS_DECLARE_PARENT(SeqVar, Var); explicit SeqVar(const VarPtr subvar) : subvar_(nullptr) { subvar_ = subvar; } - bool matches(const BaseRef& value) override { + bool matches(const BaseRef &value) override { // match Seq. if (utils::isa(value)) { - const Seq& seq = utils::cast(value); - return std::all_of(seq.begin(), seq.end(), [this](const BaseRef& v) { + const Seq &seq = utils::cast(value); + return std::all_of(seq.begin(), seq.end(), [this](const BaseRef &v) { auto eq = subvar_->matches(v); return eq; }); } return false; } - bool operator==(const SeqVar& other) const { return *subvar_ == *other.subvar_; } + bool operator==(const SeqVar &other) const { return *subvar_ == *other.subvar_; } std::string ToString() const override; private: VarPtr subvar_; }; -bool operator==(const VarPtr& lhs, const VarPtr& rhs); +bool operator==(const VarPtr &lhs, const VarPtr &rhs); -inline bool operator!=(const VarPtr& lhs, const VarPtr& rhs) { return !(lhs == rhs); } +inline bool operator!=(const VarPtr &lhs, const VarPtr &rhs) { return !(lhs == rhs); } -std::ostream& operator<<(std::ostream& os, const VarPtr& var); +std::ostream &operator<<(std::ostream &os, const VarPtr &var); using Equiv = std::map; using EquivPtr = std::shared_ptr; +using PrimitiveVarMap = std::unordered_map; +using PrimitiveVarMapPtr = std::shared_ptr; -inline bool DefaultTypeEq(const BaseRef& x, const BaseRef& y) { return x.type() == y.type(); } +inline bool DefaultTypeEq(const BaseRef &x, const BaseRef &y) { return x.type() == y.type(); } class PatternEngine { public: - PatternEngine(const std::shared_ptr& visitor, const std::function& eq, - const std::function& type_eq = DefaultTypeEq) + PatternEngine(const std::shared_ptr &visitor, + const std::function &eq, + const std::function &type_eq = DefaultTypeEq) : visitor_(visitor), eq_(eq), type_eq_(type_eq) {} ~PatternEngine() = default; - EquivPtr Match(const BaseRef& pattern, const BaseRef& expr, EquivPtr equiv) const; + EquivPtr Match(const BaseRef &pattern, const BaseRef &expr, const PrimitiveVarMap &primitive_vars, + EquivPtr equiv) const; // Replace pattern with equivalent - BaseRef Replace(const BaseRef& pattern, const EquivPtr& equiv) const; + BaseRef Replace(const BaseRef &pattern, const EquivPtr &equiv) const; private: - EquivPtr AlignSVar(const VectorRef& values_pattern, const VectorRef& values_expr, EquivPtr equiv) const; - bool ToVector(const BaseRef& pattern, const BaseRef& expr, VectorRef* const values_pattern, - VectorRef* const values_expr) const; - bool ToVector(const VectorRef& pattern_ref, const VectorRef& expr_ref, VectorRef* const values_pattern, - VectorRef* const values_expr) const; + EquivPtr AlignSVar(const VectorRef &values_pattern, const VectorRef &values_expr, + const PrimitiveVarMap &primitive_vars, EquivPtr equiv) const; + bool ToVector(const BaseRef &pattern, const BaseRef &expr, VectorRef *const values_pattern, + VectorRef *const values_expr) const; + bool ToVector(const VectorRef &pattern_ref, const VectorRef &expr_ref, VectorRef *const values_pattern, + VectorRef *const values_expr) const; std::shared_ptr visitor_; - std::function eq_; - std::function type_eq_; + std::function eq_; + std::function type_eq_; }; } // namespace mindspore namespace std { diff --git a/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc b/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc index 9124f5cf747..7b0e2cc9db8 100644 --- a/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc +++ b/tests/ut/cpp/pre_activate/common/pattern_engine_test.cc @@ -40,6 +40,7 @@ class TestMatchEngine : public UT::Common { public: PatternEngine TU; EquivPtr equiv_null; + PrimitiveVarMap primitive_vars_null; }; TEST_F(TestMatchEngine, Var) { @@ -106,30 +107,30 @@ TEST_F(TestMatchEngine, MatchRaw_Var) { // common equiv_null->clear(); - d = TU.Match(v1, 1, equiv_null); + d = TU.Match(v1, 1, primitive_vars_null, equiv_null); ASSERT_EQ((*d)[v1], 1); equiv_null->clear(); (*equiv_null)[v1] = v2; - d = TU.Match(v1, 1, equiv_null); + d = TU.Match(v1, 1, primitive_vars_null, equiv_null); ASSERT_EQ(d->count(v2), std::size_t(1)); ASSERT_EQ((*d)[v2], 1); equiv_null->clear(); (*equiv_null)[v1] = v2; (*equiv_null)[v3] = 1; - d = TU.Match(v1, 1, equiv_null); + d = TU.Match(v1, 1, primitive_vars_null, equiv_null); ASSERT_EQ(d->count(v2), std::size_t(1)); ASSERT_EQ((*d)[v2], 1); equiv_null->clear(); - d = TU.Match(VectorRef({v1}), VectorRef({1}), equiv_null); + d = TU.Match(VectorRef({v1}), VectorRef({1}), primitive_vars_null, equiv_null); ASSERT_EQ(d->size(), std::size_t(1)); ASSERT_EQ(d->count(v1), std::size_t(1)); ASSERT_EQ((*d)[v1], 1); equiv_null->clear(); - ASSERT_EQ(TU.Match(1, 2, equiv_null), nullptr); + ASSERT_EQ(TU.Match(1, 2, primitive_vars_null, equiv_null), nullptr); } TEST_F(TestMatchEngine, MatchRaw_SVar) { @@ -139,22 +140,22 @@ TEST_F(TestMatchEngine, MatchRaw_SVar) { EquivPtr d; equiv_null->clear(); - d = TU.Match(VectorRef({sv1}), VectorRef({1, 2}), equiv_null); + d = TU.Match(VectorRef({sv1}), VectorRef({1, 2}), primitive_vars_null, equiv_null); ASSERT_EQ(d->size(), std::size_t(1)); ASSERT_EQ(d->count(sv1), std::size_t(1)); ASSERT_EQ(utils::cast((*d)[sv1]), Seq({1, 2})); equiv_null->clear(); - d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 2}), equiv_null); + d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 2}), primitive_vars_null, equiv_null); ASSERT_EQ(d->size(), std::size_t(2)); ASSERT_EQ(utils::cast((*d)[sv1]), Seq({2})); equiv_null->clear(); - ASSERT_EQ(TU.Match(VectorRef({sv1, sv2}), VectorRef({1, 2}), equiv_null), nullptr); + ASSERT_EQ(TU.Match(VectorRef({sv1, sv2}), VectorRef({1, 2}), primitive_vars_null, equiv_null), nullptr); equiv_null->clear(); (*equiv_null)[sv1] = std::make_shared(PatternListType{1, 2}); - d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 1, 2}), equiv_null); + d = TU.Match(VectorRef({v1, sv1}), VectorRef({1, 1, 2}), primitive_vars_null, equiv_null); ASSERT_EQ(d->size(), std::size_t(2)); ASSERT_EQ((*d)[v1], 1); } @@ -167,13 +168,13 @@ TEST_F(TestMatchEngine, Match) { EquivPtr d; equiv_null->clear(); - d = TU.Match(VectorRef({v1, v1, v2}), VectorRef({1, 1, 2}), equiv_null); + d = TU.Match(VectorRef({v1, v1, v2}), VectorRef({1, 1, 2}), primitive_vars_null, equiv_null); ASSERT_EQ(d->size(), std::size_t(2)); ASSERT_EQ((*d)[v1], 1); ASSERT_EQ((*d)[v2], 2); equiv_null->clear(); - d = TU.Match(static_cast(1), static_cast(1), equiv_null); + d = TU.Match(static_cast(1), static_cast(1), primitive_vars_null, equiv_null); ASSERT_EQ(d, nullptr); } @@ -197,18 +198,19 @@ TEST_F(TestMatchEngine, Match_CondVar) { EquivPtr d; equiv_null->clear(); - d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast(1.0), -1}), equiv_null); + d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast(1.0), -1}), primitive_vars_null, equiv_null); ASSERT_GE(d->size(), std::size_t(0)); auto vfn = (*d)[vf]; ASSERT_EQ((*d)[vf], static_cast(1.0)); ASSERT_EQ((*d)[vn], -1); equiv_null->clear(); - d = TU.Match(VectorRef({vf, vn}), VectorRef({1, static_cast(-1.0)}), equiv_null); + d = TU.Match(VectorRef({vf, vn}), VectorRef({1, static_cast(-1.0)}), primitive_vars_null, equiv_null); ASSERT_EQ(d, nullptr); equiv_null->clear(); - d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast(1.0), static_cast(1)}), equiv_null); + d = TU.Match(VectorRef({vf, vn}), VectorRef({static_cast(1.0), static_cast(1)}), primitive_vars_null, + equiv_null); ASSERT_EQ(d, nullptr); }