From 0b172d76e48cf39897e6941983a075af137141d1 Mon Sep 17 00:00:00 2001 From: reku1997 Date: Fri, 10 Feb 2023 14:15:10 +0800 Subject: [PATCH] Optimize Ascend Backend Pass --- .../common/optimizer/graph_optimizer.cc | 4 + .../ccsrc/backend/common/optimizer/helper.cc | 12 +- .../common/optimizer/inplace_node_pass.cc | 56 ++ .../common/optimizer/inplace_node_pass.h | 49 ++ .../backend/common/optimizer/node_pass.cc | 241 ++++++- .../backend/common/optimizer/node_pass.h | 19 +- .../common/optimizer/pattern_to_pattern.cc | 253 ++++++- .../common/optimizer/pattern_to_pattern.h | 15 +- .../backend/common/pass/add_dropout_attrs.cc | 10 + .../backend/common/pass/add_dropout_attrs.h | 5 + .../common/pass/add_dynamic_shape_attr.cc | 9 +- .../common/pass/add_dynamic_shape_attr.h | 12 +- .../common/pass/conv_transpose_to_conv_bp.cc | 30 +- .../common/pass/conv_transpose_to_conv_bp.h | 12 +- .../pass/convert_attr_to_unify_mindir.cc | 7 +- .../pass/convert_attr_to_unify_mindir.h | 12 +- .../pass/convert_dynamic_broadcast_to.cc | 55 +- .../pass/convert_dynamic_broadcast_to.h | 17 +- .../common/pass/custom_op_reg_info_to_attr.cc | 53 +- .../common/pass/custom_op_reg_info_to_attr.h | 13 +- .../common/pass/flatten_concat_fission.cc | 7 + .../common/pass/flatten_concat_fission.h | 5 + .../pass/inplace_assign_for_custom_op.cc | 6 + .../pass/inplace_assign_for_custom_op.h | 5 + .../enhancer/add_attr_for_3d_graph.cc | 14 +- .../reselect_call_inline_format.cc | 6 +- .../ascend/optimizer/ge/reduce_axis_update.cc | 10 +- .../ir_fission/reduce_min_fission.cc | 18 +- .../optimizer/mindir/aicpu_lib_select.cc | 11 +- .../optimizer/mindir/aicpu_lib_select.h | 9 +- .../mindir/all_to_all_unify_mindir.cc | 12 + .../mindir/all_to_all_unify_mindir.h | 6 + .../mindir/avg_pool_grad_unify_mindir.cc | 107 ++- .../mindir/avg_pool_grad_unify_mindir.h | 12 +- .../optimizer/mindir/bn_grad_unify_mindir.cc | 58 +- .../optimizer/mindir/bn_grad_unify_mindir.h | 13 +- .../optimizer/mindir/dropout_unify_mindir.cc | 180 ++--- .../optimizer/mindir/dropout_unify_mindir.h | 16 +- .../optimizer/mindir/fse_decode_adjust.cc | 7 + .../optimizer/mindir/fse_decode_adjust.h | 4 + .../mindir/maxpool_to_maxpool_with_argmax.cc | 8 + .../mindir/maxpool_to_maxpool_with_argmax.h | 2 + .../maxpool_with_argmax_unify_mindir.cc | 60 +- .../mindir/maxpool_with_argmax_unify_mindir.h | 22 +- .../neighbor_exchange_v2_unify_mindir.cc | 13 + .../neighbor_exchange_v2_unify_mindir.h | 4 + .../mindir/optimizer_unify_output.cc | 170 +++-- .../optimizer/mindir/optimizer_unify_output.h | 52 +- .../mindir/quant_dtype_cast_adjust.cc | 6 + .../mindir/quant_dtype_cast_adjust.h | 4 + .../mindir/slice_grad_unify_mindir.cc | 52 +- .../mindir/slice_grad_unify_mindir.h | 12 +- .../mindir/space_batch_nd_attr_update.cc | 57 +- .../mindir/space_batch_nd_attr_update.h | 22 +- ..._cross_entropy_with_logits_unify_mindir.cc | 21 + ...x_cross_entropy_with_logits_unify_mindir.h | 10 + .../update_input_names_strided_slice_grad.cc | 47 +- .../update_input_names_strided_slice_grad.h | 13 +- mindspore/core/ir/manager.cc | 12 + mindspore/core/ir/manager.h | 25 +- mindspore/core/ir/value.cc | 4 +- .../fast_pattern_to_pattern_pass_test.cc | 617 ++++++++++++++++++ .../common/pattern_to_pattern_pass_test.cc | 67 +- .../common/pattern_to_pattern_pass_utils.h | 70 ++ 64 files changed, 2197 insertions(+), 563 deletions(-) create mode 100644 mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.cc create mode 100644 mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.h create mode 100644 tests/ut/cpp/pre_activate/common/fast_pattern_to_pattern_pass_test.cc create mode 100644 tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_utils.h diff --git a/mindspore/ccsrc/backend/common/optimizer/graph_optimizer.cc b/mindspore/ccsrc/backend/common/optimizer/graph_optimizer.cc index 123b3ff895b..7b3211d0c53 100644 --- a/mindspore/ccsrc/backend/common/optimizer/graph_optimizer.cc +++ b/mindspore/ccsrc/backend/common/optimizer/graph_optimizer.cc @@ -45,6 +45,10 @@ FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_o std::vector func_graphs; func_graphs.push_back(func_graph); (void)TopoSort(func_graph->get_return()); + auto func_graph_index = manager->func_graph_index(func_graph); + MS_EXCEPTION_IF_NULL(func_graph_index); + func_graph_index->set_has_gen_index(false); + return func_graph; } } // namespace opt diff --git a/mindspore/ccsrc/backend/common/optimizer/helper.cc b/mindspore/ccsrc/backend/common/optimizer/helper.cc index 0875191b93c..a76a604a6af 100644 --- a/mindspore/ccsrc/backend/common/optimizer/helper.cc +++ b/mindspore/ccsrc/backend/common/optimizer/helper.cc @@ -704,8 +704,16 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) { if (b_value_ptr == nullptr) { MS_LOG(EXCEPTION) << "Value ptr is nullptr."; } - - return (*a_value_ptr) == (*b_value_ptr); + if (a_value_ptr->isa() && b_value_ptr->isa()) { + auto a_tensor_ptr = a_value_ptr->cast(); + auto b_tensor_ptr = b_value_ptr->cast(); + if (a_tensor_ptr == nullptr || b_tensor_ptr == nullptr) { + MS_LOG(EXCEPTION) << "Cast value node ptr fail."; + } + return a_tensor_ptr->ValueEqual(*b_tensor_ptr); + } else { + return (*a_value_ptr) == (*b_value_ptr); + } } MS_LOG(DEBUG) << "check AnfNodePtr equal"; } diff --git a/mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.cc b/mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.cc new file mode 100644 index 00000000000..128492f49cf --- /dev/null +++ b/mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/common/optimizer/inplace_node_pass.h" + +namespace mindspore { +namespace opt { +AnfNodePtr InplaceNodePass::Run(const FuncGraphPtr &, const AnfNodePtr &node) { + std::vector pre_inputs; + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + pre_inputs.insert(pre_inputs.end(), inputs.begin(), inputs.end()); + } + bool ret = Process(node); + if (node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto inputs = cnode->inputs(); + if (inputs.size() != pre_inputs.size()) { + MS_LOG(EXCEPTION) << "InplaceNodePass ERROR, the pass modify node: " << node->DebugString() + << ", pass name: " << name(); + } + for (size_t i = 0; i < inputs.size(); i++) { + MS_EXCEPTION_IF_NULL(inputs[i]); + MS_EXCEPTION_IF_NULL(pre_inputs[i]); + if (!opt::AnfEqual(inputs[i], pre_inputs[i])) { + MS_LOG(EXCEPTION) << "InplaceNodePass ERROR, the pass modify node: " << node->DebugString() + << ", pass name: " << name() << ", before node " << i << ":" << inputs[i]->DebugString() + << ", after node " << i << ":" << pre_inputs[i]->DebugString(); + } + } + } + if (ret) { + return node; + } else { + return nullptr; + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.h b/mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.h new file mode 100644 index 00000000000..6403548a579 --- /dev/null +++ b/mindspore/ccsrc/backend/common/optimizer/inplace_node_pass.h @@ -0,0 +1,49 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_INPLACE_NODE_PASS_H +#define MINDSPORE_INPLACE_NODE_PASS_H + +#include +#include +#include + +#include "utils/hash_map.h" +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/primitive.h" +#include "backend/common/optimizer/pass_manager.h" +#include "backend/common/optimizer/pattern_engine.h" +#include "ir/graph_utils.h" +#include "utils/ms_utils.h" +#include "backend/common/optimizer/helper.h" +#include "backend/common/optimizer/graph_optimizer.h" +#include "include/backend/visible.h" + +namespace mindspore { +namespace opt { +class BACKEND_EXPORT InplaceNodePass : public NodePass { + public: + explicit InplaceNodePass(const std::string &name = "") : NodePass(name) {} + ~InplaceNodePass() override = default; + virtual bool Process(const AnfNodePtr &) const = 0; + AnfNodePtr Run(const FuncGraphPtr &, const AnfNodePtr &node) override; + bool IsFastPass() override { return true; } +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_INPLACE_NODE_PASS_H diff --git a/mindspore/ccsrc/backend/common/optimizer/node_pass.cc b/mindspore/ccsrc/backend/common/optimizer/node_pass.cc index f3ee701ebd2..6bd86b7f7d7 100644 --- a/mindspore/ccsrc/backend/common/optimizer/node_pass.cc +++ b/mindspore/ccsrc/backend/common/optimizer/node_pass.cc @@ -17,6 +17,9 @@ #include #include +#include +#include +#include #include "ir/anf.h" #include "ir/func_graph.h" #include "ir/manager.h" @@ -27,14 +30,18 @@ namespace mindspore { namespace opt { +namespace { const size_t kSwitchBranchIndex = 2; const size_t kCallArgsIndex = 1; const size_t kPartialArgsIndex = 1; +} // namespace void UpdateCallerAbstract(const AnfNodePtr &call_node, const FuncGraphPtr &call_node_fg, const FuncGraphPtr &sub_graph) { MS_EXCEPTION_IF_NULL(call_node); MS_EXCEPTION_IF_NULL(call_node_fg); + MS_EXCEPTION_IF_NULL(sub_graph); + MS_EXCEPTION_IF_NULL(sub_graph->output()); call_node->set_abstract(sub_graph->output()->abstract()); auto manager = call_node_fg->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -58,9 +65,8 @@ void UpdateCallerAbstract(const AnfNodePtr &call_node, const FuncGraphPtr &call_ } } -void AddOutputAndCallerToMap( - const CNodePtr &cnode, const FuncGraphPtr &fg, - mindspore::HashMap>> *out_caller_map) { +void ModifyOutputAndCallerToMap(const CNodePtr &cnode, const FuncGraphPtr &fg, + mindspore::HashMap> *out_caller_map, bool is_add) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(out_caller_map); auto inputs = cnode->inputs(); @@ -73,28 +79,39 @@ void AddOutputAndCallerToMap( } auto switch_subgraph = GetValueNode(partial_inputs.at(kPartialArgsIndex)); MS_EXCEPTION_IF_NULL(switch_subgraph); - (*out_caller_map)[switch_subgraph->output()].emplace_back(cnode, fg); - UpdateCallerAbstract(cnode, fg, switch_subgraph); + if (is_add) { + (*out_caller_map)[switch_subgraph->output()].insert(cnode); + UpdateCallerAbstract(cnode, fg, switch_subgraph); + } else { + (*out_caller_map)[switch_subgraph->output()].erase(cnode); + } } else if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { auto call_subgraph = GetValueNode(inputs.at(kCallArgsIndex)); MS_EXCEPTION_IF_NULL(call_subgraph); - (*out_caller_map)[call_subgraph->output()].emplace_back(cnode, fg); - UpdateCallerAbstract(cnode, fg, call_subgraph); + if (is_add) { + (*out_caller_map)[call_subgraph->output()].insert(cnode); + UpdateCallerAbstract(cnode, fg, call_subgraph); + } else { + (*out_caller_map)[call_subgraph->output()].erase(cnode); + } } } -void UpdateSubGraphCaller( - const AnfNodePtr &origin_output, const FuncGraphPtr &fg, - mindspore::HashMap>> *out_caller_map) { +void UpdateSubGraphCaller(const AnfNodePtr &origin_output, const FuncGraphPtr &fg, + mindspore::HashMap> *out_caller_map, + const mindspore::HashMap &node_to_fg) { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(fg->output()); auto find_iter = (*out_caller_map).find(origin_output); if (find_iter != (*out_caller_map).end()) { auto call_node_list = find_iter->second; (*out_caller_map).erase(find_iter); - for (auto &call_node_pair : call_node_list) { - auto call_node = call_node_pair.first; - auto call_node_fg = call_node_pair.second; + for (auto &call_node : call_node_list) { + auto fg_iter = node_to_fg.find(call_node); + if (fg_iter == node_to_fg.end()) { + MS_LOG(EXCEPTION) << "Node to Funcgraph find failed: " << call_node->fullname_with_scope(); + } + auto call_node_fg = fg_iter->second.lock(); UpdateCallerAbstract(call_node, call_node_fg, fg); } (*out_caller_map)[fg->output()] = call_node_list; @@ -111,22 +128,169 @@ void SkipSameOp(const AnfNodePtr &old_node, const AnfNodePtr &new_node, mindspor } } -bool NodePass::Run(const FuncGraphPtr &func_graph) { +std::string GetCNodeKey(const AnfNodePtr &node) { + auto primitive = GetCNodePrimitive(node); + if (primitive != nullptr) { + return primitive->name(); + } else { + return ""; + } +} + +void GenIndex(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index) { MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(func_graph_index); + if (func_graph_index->has_gen_index()) { + return; + } + + func_graph_index->set_has_gen_index(true); + func_graph_index->node_to_fg_.clear(); + func_graph_index->node_degree_.clear(); + func_graph_index->name_to_cnode_.clear(); + func_graph_index->subgraph_out_caller_map_.clear(); + FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - - // maybe call subgraph many times - mindspore::HashMap>> subgraph_out_caller_map = {}; mindspore::HashSet seen_node; std::deque> todo{{func_graph->output(), func_graph}}; - bool changes = false; + while (!todo.empty()) { AnfNodePtr node = todo.front().first; + MS_EXCEPTION_IF_NULL(node); auto fg = todo.front().second; manager->AddFuncGraph(fg); todo.pop_front(); + + func_graph_index->node_to_fg_[node] = fg; + auto degree_iter = func_graph_index->node_degree_.find(node); + if (degree_iter == func_graph_index->node_degree_.end()) { + func_graph_index->node_degree_[node] = 1; + } else { + degree_iter->second++; + } + if (node->isa()) { + func_graph_index->name_to_cnode_[GetCNodeKey(node)].insert(node); + } + + if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { + continue; + } + (void)seen_node.insert(node); + TraceGuard guard(std::make_shared(node->debug_info())); + + if (IsValueNode(node)) { + auto const_func_graph = GetValueNode(node); + MS_EXCEPTION_IF_NULL(const_func_graph); + if (!const_func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + (void)todo.emplace_back(const_func_graph->output(), const_func_graph); + } + } else if (node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + ModifyOutputAndCallerToMap(cnode, fg, &func_graph_index->subgraph_out_caller_map_); + auto inputs = cnode->inputs(); + (void)std::for_each(inputs.begin(), inputs.end(), + [&fg, &todo](AnfNodePtr &node) { (void)todo.emplace_back(node, fg); }); + } + } +} + +bool NodePass::ProcessFastPassNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph, + const FuncGraphIndexPtr &func_graph_index, const FuncGraphManagerPtr &manager) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(func_graph_index); + MS_EXCEPTION_IF_NULL(manager); + auto iter = func_graph_index->node_to_fg_.find(node); + if (iter == func_graph_index->node_to_fg_.end()) { + MS_LOG(EXCEPTION) << "Node to Funcgraph map can't find node: " << node->fullname_with_scope(); + } + auto fg = iter->second.lock(); + TraceGuard guard(std::make_shared(node->debug_info())); + auto degree_iter = func_graph_index->node_degree_.find(node); + if (degree_iter == func_graph_index->node_degree_.end()) { + MS_LOG(EXCEPTION) << "Node degree map can't find node: " << node->fullname_with_scope(); + } + auto degree = degree_iter->second; + if (degree == 0 && node != func_graph->output()) { + return false; + } + // we may update return value in some pass. + MS_EXCEPTION_IF_NULL(fg); + auto origin_output = fg->output(); + MS_EXCEPTION_IF_NULL(origin_output); + auto origin_abstract = origin_output->abstract(); + AnfNodePtr new_node = Run(fg, node); + bool change = (new_node != nullptr); + MS_EXCEPTION_IF_NULL(fg->output()); + if (origin_abstract != fg->output()->abstract()) { + UpdateSubGraphCaller(origin_output, fg, &func_graph_index->subgraph_out_caller_map_, func_graph_index->node_to_fg_); + } + if (new_node != nullptr && new_node != node) { + (void)manager->Replace(node, new_node); + // if replaced node is end_goto, refresh relative params in kernel graph + auto kernel_graph = fg->cast>(); + if (kernel_graph != nullptr && node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto end_label = kernel_graph->get_end_goto(); + if (cnode == end_label && common::AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) { + kernel_graph->set_end_goto(new_node->cast()); + } + } + AfterProcess(node, new_node, fg, func_graph_index); + } + return change; +} + +bool NodePass::ProcessFastPass(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(func_graph_index); + if (!func_graph_index->has_gen_index()) { + MS_LOG(EXCEPTION) << "ProcessFastPass Error, func graph has not gen index, pass name: " << name(); + } + auto src_pattern_root_name = GetPatternRootPrimitiveName(); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + bool changes = false; + + std::vector cand_node; + if (!src_pattern_root_name.empty()) { + auto cnode_iter = func_graph_index->name_to_cnode_.find(src_pattern_root_name); + if (cnode_iter == func_graph_index->name_to_cnode_.end()) { + return false; + } + std::copy(cnode_iter->second.begin(), cnode_iter->second.end(), std::back_inserter(cand_node)); + } else { + for (const auto &kv : func_graph_index->name_to_cnode_) { + std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(cand_node)); + } + } + for (const auto &node : cand_node) { + auto change = ProcessFastPassNode(node, func_graph, func_graph_index, manager); + changes = changes || change; + } + return changes; +} + +bool NodePass::ProcessPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(manager); + bool changes = false; + + // maybe call subgraph many times + mindspore::HashMap> subgraph_out_caller_map = {}; + mindspore::HashMap node_to_fg = {}; + mindspore::HashSet seen_node; + std::deque> todo{{func_graph->output(), func_graph}}; + while (!todo.empty()) { + AnfNodePtr node = todo.front().first; + auto fg = todo.front().second; + MS_EXCEPTION_IF_NULL(node); + manager->AddFuncGraph(fg); + todo.pop_front(); + node_to_fg[node] = fg; if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) { continue; } @@ -140,7 +304,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) { AnfNodePtr new_node = Run(fg, node); bool change = (new_node != nullptr); if (origin_abstract != fg->output()->abstract()) { - UpdateSubGraphCaller(origin_output, fg, &subgraph_out_caller_map); + UpdateSubGraphCaller(origin_output, fg, &subgraph_out_caller_map, node_to_fg); } if (new_node != nullptr && new_node != node) { SkipSameOp(node, new_node, &seen_node); @@ -171,15 +335,44 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) { } auto cnode = new_node->cast(); MS_EXCEPTION_IF_NULL(cnode); - AddOutputAndCallerToMap(cnode, fg, &subgraph_out_caller_map); + ModifyOutputAndCallerToMap(cnode, fg, &subgraph_out_caller_map); auto inputs = cnode->inputs(); - (void)std::for_each(inputs.begin(), inputs.end(), [&fg, &todo](AnfNodePtr &node) { - (void)todo.emplace_back(std::pair(node, fg)); - }); + (void)std::for_each(inputs.begin(), inputs.end(), + [&fg, &todo](AnfNodePtr &node) { (void)todo.emplace_back(node, fg); }); } changes = changes || change; } return changes; } + +bool NodePass::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(func_graph); + auto func_graph_index = manager->func_graph_index(func_graph); + MS_EXCEPTION_IF_NULL(func_graph_index); + + if (IsFastPass()) { + MS_LOG(INFO) << "Run fast pass: " << name(); + GenIndex(func_graph, func_graph_index); + return ProcessFastPass(func_graph, func_graph_index); + } + if (func_graph_index->has_gen_index()) { + auto ret = MustExistPrimitiveName(); + for (const auto &primtive_name : ret) { + auto cnode_iter = func_graph_index->name_to_cnode_.find(primtive_name); + if (cnode_iter == func_graph_index->name_to_cnode_.end()) { + return false; + } + } + if (!ret.empty()) { + MS_LOG(INFO) << "Skip pass fail, run pass: " << name(); + } + } + func_graph_index->set_has_gen_index(false); + + return ProcessPass(func_graph, manager); +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/optimizer/node_pass.h b/mindspore/ccsrc/backend/common/optimizer/node_pass.h index 96823fa4cf8..5c4f6e38895 100644 --- a/mindspore/ccsrc/backend/common/optimizer/node_pass.h +++ b/mindspore/ccsrc/backend/common/optimizer/node_pass.h @@ -17,6 +17,8 @@ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ #include #include +#include +#include #include "backend/common/optimizer/pass.h" #include "include/backend/visible.h" @@ -28,10 +30,25 @@ class BACKEND_EXPORT NodePass : public Pass { public: explicit NodePass(const std::string &name) : Pass(name) {} ~NodePass() override = default; - virtual bool Run(const FuncGraphPtr &func_graph); + bool Run(const FuncGraphPtr &func_graph) override; + virtual bool IsFastPass() { return false; } + virtual void AfterProcess(const AnfNodePtr &, const AnfNodePtr &, const FuncGraphPtr &, const FuncGraphIndexPtr &) {} + virtual std::string GetPatternRootPrimitiveName() { return ""; } virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0; + virtual std::vector MustExistPrimitiveName() const { return {}; } + + private: + bool ProcessFastPassNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph, + const FuncGraphIndexPtr &func_graph_index, const FuncGraphManagerPtr &manager); + bool ProcessFastPass(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index); + bool ProcessPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager); }; using NodePassPtr = std::shared_ptr; +void GenIndex(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index); +void ModifyOutputAndCallerToMap(const CNodePtr &cnode, const FuncGraphPtr &fg, + mindspore::HashMap> *out_caller_map, + bool is_add = true); +std::string GetCNodeKey(const AnfNodePtr &node); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_ diff --git a/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.cc b/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.cc index 42b7c98bdb4..25343832b84 100644 --- a/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.cc +++ b/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.cc @@ -16,7 +16,10 @@ #include "backend/common/optimizer/pattern_to_pattern.h" #include +#include +#include #include "ir/manager.h" +#include "include/common/utils/anfalgo.h" namespace mindspore { namespace opt { @@ -60,15 +63,19 @@ const std::vector &PatternMap::GetSeq(const std::string &name) const } bool PatternMap::Emplace(const std::string &name, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); name_set_.insert(name); if (seq_map_.find(name) != seq_map_.end()) { MS_LOG(EXCEPTION) << "Var Key: " << name << " should not be in SeqVarMap."; } + opt_scope_.insert(node); + auto iter = node_map_.find(name); if (iter == node_map_.end()) { node_map_.emplace(name, node); } else if (!opt::AnfEqual(node, iter->second)) { + MS_EXCEPTION_IF_NULL(iter->second); MS_LOG(INFO) << "The value of key: " << name << " is not equal to origin value, value: " + node->fullname_with_scope() << " origin value: " << iter->second->fullname_with_scope(); @@ -83,6 +90,10 @@ bool PatternMap::Emplace(const std::string &name, const std::vector MS_LOG(EXCEPTION) << "SeqVar Key: " << name << " should not be in VarMap."; } + for (const auto &node : v) { + opt_scope_.insert(node); + } + auto iter = seq_map_.find(name); if (iter == seq_map_.end()) { seq_map_.emplace(name, v); @@ -96,6 +107,8 @@ bool PatternMap::Emplace(const std::string &name, const std::vector } for (size_t i = 0; i < v.size(); i++) { + MS_EXCEPTION_IF_NULL(v[i]); + MS_EXCEPTION_IF_NULL(origin_v[i]); if (!opt::AnfEqual(v[i], origin_v[i])) { MS_LOG(INFO) << "The value of key: " << name << " is not equal to origin value, value: " + v[i]->fullname_with_scope() @@ -181,6 +194,7 @@ BaseRef SrcPattern::GetRoot() const { const Seq &GetSeq(const std::string &pattern_name, const std::string &node_name, const VarPtr &var, const EquivPtr &equiv) { + MS_EXCEPTION_IF_NULL(equiv); auto equiv_iter = equiv->find(var); if (equiv_iter == equiv->end()) { MS_LOG(EXCEPTION) << "The SeqVar Key: " << pattern_name << " is not in EquivMap, node name: " << node_name; @@ -204,6 +218,7 @@ bool SrcPattern::CheckEmptySeqVar(const std::string &name, const EquivPtr &equiv MS_EXCEPTION_IF_CHECK_FAIL(seq.size() == IntToSize(0), "Match Failed, need zero seq, but get seq length: " + std::to_string(seq.size()) + ", node name: " + name); std::vector v; + MS_EXCEPTION_IF_NULL(m_); if (!m_->Emplace(pattern_node.name_, v)) { return false; } @@ -214,6 +229,9 @@ bool SrcPattern::CheckEmptySeqVar(const std::string &name, const EquivPtr &equiv } bool SrcPattern::match(const std::string &name, const AnfNodePtr &node, const EquivPtr &equiv) { + MS_EXCEPTION_IF_NULL(m_); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(equiv); auto input_iter = inputs_map_.find(name); if (input_iter == inputs_map_.end()) { MS_LOG(EXCEPTION) << "Key: " << name << " is not a CNode."; @@ -234,6 +252,8 @@ bool SrcPattern::match(const std::string &name, const AnfNodePtr &node, const Eq auto &match_node = cnode_inputs[now_match]; if (pattern_node.type_ == "prim") { // prim + MS_EXCEPTION_IF_NULL(pattern_node.p_); + MS_EXCEPTION_IF_NULL(match_node); if (!opt::AnfEqual(pattern_node.p_, match_node)) { MS_LOG(EXCEPTION) << "The value of Primitive is not equal to matched value, pattern value: " + pattern_node.p_->ToString() @@ -296,6 +316,7 @@ bool SrcPattern::build_pattern_map(const AnfNodePtr &node, const EquivPtr &equiv DstPattern &DstPattern::AddCNode(const string &name, const std::initializer_list &inputs, const BuildCNodeFunc &buildfunc) { + MS_EXCEPTION_IF_NULL(m_); if (fail_) { return *this; } @@ -343,10 +364,12 @@ DstPattern &DstPattern::AddCNode(const string &name, const std::initializer_list << ", CNode: " << name; } for (size_t i = 0; i < anf_inputs.size(); i++) { + MS_EXCEPTION_IF_NULL(anf_inputs[i]); + MS_EXCEPTION_IF_NULL(cnode->input(i)); if (!opt::AnfEqual(anf_inputs[i], cnode->input(i))) { MS_LOG(EXCEPTION) << "The actual input does not correspond to the input of the pattern, the input index: " << i - << ", actual input: " << anf_inputs[i]->fullname_with_scope() - << ", pattern input: " << new_node->cast()->input(i)->fullname_with_scope() + << ", actual input: " << anf_inputs[i]->DebugString() + << ", pattern input: " << new_node->cast()->input(i)->DebugString() << ", CNode: " << name; } } @@ -360,6 +383,7 @@ DstPattern &DstPattern::AddCNode(const string &name, const std::initializer_list } DstPattern &DstPattern::AddValueNode(const string &name, const BuildValueFunc &buildfunc) { + MS_EXCEPTION_IF_NULL(m_); if (fail_) { return *this; } @@ -379,6 +403,7 @@ DstPattern &DstPattern::AddValueNode(const string &name, const BuildValueFunc &b } void DstPattern::clear() { + MS_EXCEPTION_IF_NULL(m_); fail_ = false; root_ = nullptr; m_->Erase(dst_set_); @@ -406,12 +431,28 @@ UnpackNode &UnpackNode::operator=(const std::string &name) { return *this; } -AnfNodePtr PatternToPatternPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { +AnfNodePtr PatternToPatternPass::GetSrcPatternRoot() { if (src_pattern_root_ == nullptr) { DefineSrcPattern(&src_pattern_); VarPtr fg = std::make_shared("RootG"); src_pattern_root_ = SexpToNode(src_pattern_.GetRoot(), fg, primitive_vars_.get(), multigraph_); } + return src_pattern_root_; +} + +std::string PatternToPatternPass::GetPatternRootPrimitiveName() { + auto src_pattern_root = GetSrcPatternRoot(); + auto prim = GetCNodePrimitive(src_pattern_root); + if (prim != nullptr) { + return prim->name(); + } + return ""; +} + +AnfNodePtr PatternToPatternPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + if (src_pattern_root_ == nullptr) { + (void)GetSrcPatternRoot(); + } auto primitive = GetCNodePrimitive(src_pattern_root_); if (IsPrimitiveCNode(node, primitive)) { @@ -435,11 +476,217 @@ AnfNodePtr PatternToPatternPass::Run(const FuncGraphPtr &func_graph, const AnfNo return nullptr; } +namespace { +const auto kStageZero = 0; +const auto kStageOne = 1; +const auto kStageTwo = 2; + +void DeleteCNode(const AnfNodePtr &node, const FuncGraphPtr &sub_graph, const FuncGraphIndexPtr &func_graph_index) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(func_graph_index); + if (node->isa()) { + auto name_to_cnode_iter = func_graph_index->name_to_cnode_.find(GetCNodeKey(node)); + if (name_to_cnode_iter == func_graph_index->name_to_cnode_.end()) { + MS_LOG(EXCEPTION) << "ProcessFastPass Error, name_to_cnode_ can't find cnode_name: " + << common::AnfAlgo::GetCNodeName(node); + } + auto &cnode_set = name_to_cnode_iter->second; + auto cnode_set_iter = cnode_set.find(node); + if (cnode_set_iter == cnode_set.end()) { + MS_LOG(EXCEPTION) << "ProcessFastPass Error, name_to_cnode_ can't find node: " << node->fullname_with_scope(); + } + cnode_set.erase(cnode_set_iter); + ModifyOutputAndCallerToMap(node->cast(), sub_graph, &func_graph_index->subgraph_out_caller_map_, false); + } +} + +void AppendChild(const AnfNodePtr &node, const FuncGraphPtr &fg, + std::queue> *anf_q) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(fg); + MS_EXCEPTION_IF_NULL(anf_q); + if (IsValueNode(node)) { + auto const_func_graph = GetValueNode(node); + MS_EXCEPTION_IF_NULL(const_func_graph); + if (!const_func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + anf_q->emplace(const_func_graph->output(), const_func_graph); + } + } else if (node->isa()) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + for (const auto &input_node : cnode->inputs()) { + anf_q->emplace(input_node, fg); + } + } +} + +bool DelSrcPattern(const std::pair &top, const AnfNodePtr &root, + const mindspore::HashSet &opt_scope, + std::set> *need_delete, + const FuncGraphIndexPtr &func_graph_index) { + MS_EXCEPTION_IF_NULL(root); + MS_EXCEPTION_IF_NULL(need_delete); + MS_EXCEPTION_IF_NULL(func_graph_index); + auto node = top.first; + auto fg = top.second; + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(fg); + if (node != root) { + auto degree_iter = func_graph_index->node_degree_.find(node); + if (degree_iter == func_graph_index->node_degree_.end()) { + MS_LOG(EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope() << " not in degree map"; + } + if (degree_iter->second <= 0) { + MS_LOG(EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope() + << " degree error, degree: " << degree_iter->second; + } + degree_iter->second--; + if (degree_iter->second > 0) { + return false; + } + } + if (opt_scope.find(node) == opt_scope.end()) { + (*need_delete).insert({node, fg}); + return false; + } + + DeleteCNode(node, fg, func_graph_index); + return true; +} + +bool AddDstPattern(const std::pair &top, const AnfNodePtr &root, + const mindspore::HashSet &opt_scope, + std::set> *need_delete, + const FuncGraphIndexPtr &func_graph_index) { + MS_EXCEPTION_IF_NULL(root); + MS_EXCEPTION_IF_NULL(need_delete); + MS_EXCEPTION_IF_NULL(func_graph_index); + auto node = top.first; + auto fg = top.second; + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(fg); + if (node->isa()) { + ModifyOutputAndCallerToMap(node->cast(), fg, &func_graph_index->subgraph_out_caller_map_); + func_graph_index->name_to_cnode_[GetCNodeKey(node)].insert(node); + func_graph_index->node_to_fg_[node] = fg; + } + + if (node != root) { + auto degree_iter = func_graph_index->node_degree_.find(node); + if (degree_iter == func_graph_index->node_degree_.end()) { + func_graph_index->node_degree_[node] = 0; + degree_iter = func_graph_index->node_degree_.find(node); + } + degree_iter->second++; + if (degree_iter->second != 1) { + return false; + } + } + if (opt_scope.find(node) == opt_scope.end()) { + (*need_delete).erase({node, fg}); + return false; + } + return true; +} + +bool DelCascadeNode(const std::pair &top, + std::set> *need_delete, + const FuncGraphIndexPtr &func_graph_index) { + MS_EXCEPTION_IF_NULL(need_delete); + MS_EXCEPTION_IF_NULL(func_graph_index); + auto node = top.first; + auto fg = top.second; + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(fg); + if ((*need_delete).find({node, fg}) == (*need_delete).end()) { + auto degree_iter = func_graph_index->node_degree_.find(node); + if (degree_iter == func_graph_index->node_degree_.end()) { + MS_LOG(EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope() << " not in degree map"; + } + if (degree_iter->second <= 0) { + MS_LOG(EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope() + << " degree error, degree: " << degree_iter->second; + } + degree_iter->second--; + if (degree_iter->second > 0) { + return false; + } + } + + DeleteCNode(node, fg, func_graph_index); + return true; +} + +void BFS(const AnfNodePtr &root, const FuncGraphPtr &sub_graph, const mindspore::HashSet &opt_scope, + std::set> *need_delete, const FuncGraphIndexPtr &func_graph_index, + size_t stage) { + std::queue> anf_q; + + if (stage == kStageZero || stage == kStageOne) { + anf_q.emplace(root, sub_graph); + } else if (stage == kStageTwo) { + for (const auto &p : (*need_delete)) { + anf_q.push(p); + } + } else { + MS_LOG(EXCEPTION) << "Illegal BFS stage, expected stage is 0/1/2, but get stage: " << stage; + } + + while (!anf_q.empty()) { + auto top = anf_q.front(); + anf_q.pop(); + + bool ret = false; + if (stage == kStageZero) { + ret = DelSrcPattern(top, root, opt_scope, need_delete, func_graph_index); + } else if (stage == kStageOne) { + ret = AddDstPattern(top, root, opt_scope, need_delete, func_graph_index); + } else if (stage == kStageTwo) { + ret = DelCascadeNode(top, need_delete, func_graph_index); + } else { + MS_LOG(EXCEPTION) << "Illegal BFS stage, expected stage is 0/1/2, but get stage: " << stage; + } + if (!ret) { + continue; + } + + AppendChild(top.first, top.second, &anf_q); + } +} +} // namespace + +void PatternToPatternPass::AfterProcess(const AnfNodePtr &old_node, const AnfNodePtr &new_node, + const FuncGraphPtr &sub_graph, const FuncGraphIndexPtr &func_graph_index) { + MS_EXCEPTION_IF_NULL(m_); + MS_EXCEPTION_IF_NULL(old_node); + MS_EXCEPTION_IF_NULL(new_node); + MS_EXCEPTION_IF_NULL(sub_graph); + MS_EXCEPTION_IF_NULL(func_graph_index); + std::set> need_delete; + auto &opt_scope = m_->GetOptScope(); + + auto old_node_iter = func_graph_index->node_degree_.find(old_node); + if (old_node_iter == func_graph_index->node_degree_.end()) { + MS_LOG(EXCEPTION) << "ProcessFastPass Error, old_node: " << old_node->fullname_with_scope() << " not in degree map"; + } + auto origin_degree = old_node_iter->second; + + func_graph_index->node_degree_[new_node] = origin_degree; + func_graph_index->node_degree_[old_node] = 0; + + BFS(old_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageZero); + BFS(new_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageOne); + BFS(new_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageTwo); +} + std::vector PatternToPatternPass::Unpacking(const std::string &s) { + MS_EXCEPTION_IF_NULL(m_); auto v = m_->GetSeq(s); std::vector ret; std::transform(v.begin(), v.end(), std::back_inserter(ret), [](const AnfNodePtr &node) { return UnpackNode(node); }); return ret; } + +bool PatternToPatternPass::IsFastPass() { return is_fast_pass_; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.h b/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.h index a85cc7d3521..e7c968d62ee 100644 --- a/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.h +++ b/mindspore/ccsrc/backend/common/optimizer/pattern_to_pattern.h @@ -43,11 +43,13 @@ class BACKEND_EXPORT PatternMap { void Clear(); bool Check(const std::string &name, const AnfNodePtr &node) const; void Erase(const mindspore::HashSet &del_set); + const mindspore::HashSet &GetOptScope() const { return opt_scope_; } private: mindspore::HashSet name_set_; mindspore::HashMap node_map_; mindspore::HashMap> seq_map_; + mindspore::HashSet opt_scope_; }; using PatternMapPtr = std::shared_ptr; @@ -163,16 +165,22 @@ class BACKEND_EXPORT DstPattern { class BACKEND_EXPORT PatternToPatternPass : public PatternPass { public: - explicit PatternToPatternPass(const std::string &name = "", bool multigraph = true) + explicit PatternToPatternPass(const std::string &name = "", bool is_fast_pass = false, bool multigraph = true) : PatternPass(name, multigraph), m_(std::make_shared()), src_pattern_(SrcPattern(m_)), - dst_pattern_(DstPattern(m_)) {} + dst_pattern_(DstPattern(m_)), + is_fast_pass_(is_fast_pass) {} ~PatternToPatternPass() override = default; virtual void DefineSrcPattern(SrcPattern *src_pattern) = 0; virtual void DefineDstPattern(DstPattern *dst_pattern) = 0; - virtual bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const = 0; + virtual bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const { return true; } + bool IsFastPass() override; + AnfNodePtr GetSrcPatternRoot(); + std::string GetPatternRootPrimitiveName() override; AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override; + void AfterProcess(const AnfNodePtr &old_node, const AnfNodePtr &new_node, const FuncGraphPtr &sub_graph, + const FuncGraphIndexPtr &func_graph_index) override; std::vector Unpacking(const std::string &s); private: @@ -180,6 +188,7 @@ class BACKEND_EXPORT PatternToPatternPass : public PatternPass { SrcPattern src_pattern_; DstPattern dst_pattern_; AnfNodePtr src_pattern_root_ = nullptr; + bool is_fast_pass_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/add_dropout_attrs.cc b/mindspore/ccsrc/backend/common/pass/add_dropout_attrs.cc index 2cabc952196..12b7129c9ac 100644 --- a/mindspore/ccsrc/backend/common/pass/add_dropout_attrs.cc +++ b/mindspore/ccsrc/backend/common/pass/add_dropout_attrs.cc @@ -15,7 +15,11 @@ */ #include "backend/common/pass/add_dropout_attrs.h" + #include +#include +#include + #include "mindspore/core/ops/core_ops.h" #include "include/common/utils/anfalgo.h" @@ -65,6 +69,12 @@ const AnfNodePtr AddDropoutAttrs::Process(const FuncGraphPtr &func_graph, const return cnode; } +std::vector AddDropoutAttrs::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimDropout->name()); + return ret; +} + const BaseRef AddDropoutAttrs::DefinePattern() const { VarPtr Xs = std::make_shared(); return VectorRef({prim::kPrimDropout, Xs}); diff --git a/mindspore/ccsrc/backend/common/pass/add_dropout_attrs.h b/mindspore/ccsrc/backend/common/pass/add_dropout_attrs.h index 86e0cabd687..67fb84ecdf5 100644 --- a/mindspore/ccsrc/backend/common/pass/add_dropout_attrs.h +++ b/mindspore/ccsrc/backend/common/pass/add_dropout_attrs.h @@ -17,6 +17,8 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_DROPOUT_ATTRS_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_DROPOUT_ATTRS_H_ +#include +#include #include "backend/common/optimizer/optimizer.h" namespace mindspore { @@ -27,6 +29,9 @@ class AddDropoutAttrs : public PatternProcessPass { ~AddDropoutAttrs() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; + + private: + std::vector MustExistPrimitiveName() const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/add_dynamic_shape_attr.cc b/mindspore/ccsrc/backend/common/pass/add_dynamic_shape_attr.cc index dda4929ad72..377019be750 100644 --- a/mindspore/ccsrc/backend/common/pass/add_dynamic_shape_attr.cc +++ b/mindspore/ccsrc/backend/common/pass/add_dynamic_shape_attr.cc @@ -21,17 +21,16 @@ namespace mindspore { namespace opt { -const AnfNodePtr AddDynamicShapeAttr::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); +bool AddDynamicShapeAttr::Process(const AnfNodePtr &node) const { if (common::AnfAlgo::IsDynamicShape(node)) { + auto func_graph = node->func_graph(); MS_LOG(DEBUG) << "Set Dynamic Shape Attr to Node:" << node->fullname_with_scope(); auto kernel_graph = func_graph->cast(); MS_EXCEPTION_IF_NULL(kernel_graph); kernel_graph->SetGraphDynamicAttr(true); + return true; } - return node; + return false; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/add_dynamic_shape_attr.h b/mindspore/ccsrc/backend/common/pass/add_dynamic_shape_attr.h index 1276249bfc1..a1fadd603e6 100644 --- a/mindspore/ccsrc/backend/common/pass/add_dynamic_shape_attr.h +++ b/mindspore/ccsrc/backend/common/pass/add_dynamic_shape_attr.h @@ -16,17 +16,13 @@ #ifndef MINDSPORE_ADD_DYNAMIC_SHAPE_ATTR_H #define MINDSPORE_ADD_DYNAMIC_SHAPE_ATTR_H -#include -#include "ir/anf.h" -#include "include/common/utils/convert_utils.h" -#include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/inplace_node_pass.h" namespace mindspore { namespace opt { -class AddDynamicShapeAttr : public PatternProcessPass { +class AddDynamicShapeAttr : public InplaceNodePass { public: - explicit AddDynamicShapeAttr(bool multigraph = true) : PatternProcessPass("add_dynamic_shape_attr", multigraph) {} - ~AddDynamicShapeAttr() override = default; - const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; + AddDynamicShapeAttr() : InplaceNodePass("add_dynamic_shape_attr") {} + bool Process(const AnfNodePtr &node) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/conv_transpose_to_conv_bp.cc b/mindspore/ccsrc/backend/common/pass/conv_transpose_to_conv_bp.cc index 4aacadbc91d..138961e46c3 100644 --- a/mindspore/ccsrc/backend/common/pass/conv_transpose_to_conv_bp.cc +++ b/mindspore/ccsrc/backend/common/pass/conv_transpose_to_conv_bp.cc @@ -25,18 +25,12 @@ namespace mindspore { namespace opt { namespace { constexpr size_t kCNodePrimitiveIdx = 0; -} +constexpr auto kXs = "Xs"; +constexpr auto kMConv2dTrans = "m_conv2d_trans"; +constexpr auto kRConv2dBp = "r_conv2d_bp"; -const BaseRef ConvTransposeToConvBackpropInputPass::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto conv_transpose = std::make_shared(kConv2DTransposeOpName); - return VectorRef({conv_transpose, Xs}); -} - -const AnfNodePtr ConvTransposeToConvBackpropInputPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); +AnfNodePtr BuildConv2DBackpropInput(const PatternMap &m, const AnfNodePtr &default_node) { + auto node = m.Get(kMConv2dTrans); auto conv_transpose = node->cast(); MS_EXCEPTION_IF_NULL(conv_transpose); @@ -51,5 +45,19 @@ const AnfNodePtr ConvTransposeToConvBackpropInputPass::Process(const FuncGraphPt return node; } +} // namespace + +bool ConvTransposeToConvBackpropInputPass::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, + const AnfNodePtr &) const { + return true; +} + +void ConvTransposeToConvBackpropInputPass::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern).AddSeqVar(kXs).AddCNode(kMConv2dTrans, {prim::kPrimConv2DTranspose, kXs}); +} + +void ConvTransposeToConvBackpropInputPass::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern).AddCNode(kRConv2dBp, {prim::kPrimConv2DBackpropInput, kXs}, BuildConv2DBackpropInput); +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/conv_transpose_to_conv_bp.h b/mindspore/ccsrc/backend/common/pass/conv_transpose_to_conv_bp.h index aa41f69eb0d..7bbdc62a3b5 100644 --- a/mindspore/ccsrc/backend/common/pass/conv_transpose_to_conv_bp.h +++ b/mindspore/ccsrc/backend/common/pass/conv_transpose_to_conv_bp.h @@ -17,17 +17,17 @@ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONV_TRANSPOSE_TO_CONV_BP_H_ #include -#include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/pattern_to_pattern.h" namespace mindspore { namespace opt { -class ConvTransposeToConvBackpropInputPass : public PatternProcessPass { +class ConvTransposeToConvBackpropInputPass : public PatternToPatternPass { public: - explicit ConvTransposeToConvBackpropInputPass(bool multigraph = true) - : PatternProcessPass("conv_transpose_to_conv_backprop_input", multigraph) {} + ConvTransposeToConvBackpropInputPass() : PatternToPatternPass("conv_transpose_to_conv_backprop_input", true) {} ~ConvTransposeToConvBackpropInputPass() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/convert_attr_to_unify_mindir.cc b/mindspore/ccsrc/backend/common/pass/convert_attr_to_unify_mindir.cc index 328fc94b1bb..516d201ce61 100644 --- a/mindspore/ccsrc/backend/common/pass/convert_attr_to_unify_mindir.cc +++ b/mindspore/ccsrc/backend/common/pass/convert_attr_to_unify_mindir.cc @@ -23,10 +23,9 @@ namespace mindspore { namespace opt { -const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const AnfNodePtr &node, - const EquivPtr &) const { +bool ConvertAttrToUnifyMindIR::Process(const AnfNodePtr &node) const { if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) { - return nullptr; + return false; } auto cnode = node->cast(); @@ -51,7 +50,7 @@ const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const A } } - return node; + return true; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/convert_attr_to_unify_mindir.h b/mindspore/ccsrc/backend/common/pass/convert_attr_to_unify_mindir.h index 81080caf2be..7e14c982b90 100644 --- a/mindspore/ccsrc/backend/common/pass/convert_attr_to_unify_mindir.h +++ b/mindspore/ccsrc/backend/common/pass/convert_attr_to_unify_mindir.h @@ -16,17 +16,15 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_ATTR_TO_UNIFY_MINDIR_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_ATTR_TO_UNIFY_MINDIR_H_ -#include "ir/anf.h" -#include "backend/common/optimizer/optimizer.h" +#include +#include "backend/common/optimizer/inplace_node_pass.h" namespace mindspore { namespace opt { -class ConvertAttrToUnifyMindIR : public PatternProcessPass { +class ConvertAttrToUnifyMindIR : public InplaceNodePass { public: - explicit ConvertAttrToUnifyMindIR(bool multigraph = true) - : PatternProcessPass("convert_attr_to_unify_mindir", multigraph) {} - ~ConvertAttrToUnifyMindIR() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override; + ConvertAttrToUnifyMindIR() : InplaceNodePass("convert_attr_to_unify_mindir") {} + bool Process(const AnfNodePtr &node) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.cc b/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.cc index bad0d308d1f..4cae1f463dd 100644 --- a/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.cc +++ b/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.cc @@ -17,32 +17,49 @@ #include #include "backend/common/pass/convert_dynamic_broadcast_to.h" #include "ir/anf.h" -#include "backend/common/optimizer/optimizer.h" #include "include/common/utils/anfalgo.h" #include "backend/common/optimizer/helper.h" namespace mindspore { namespace opt { -const AnfNodePtr ConvertDynamicBroadcastTo::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); +namespace { +const auto kV = "V"; +const auto kMBroadcastTo = "m_broadcast_to"; +const auto kRBroadcastTo = "r_broadcast_to"; +AnfNodePtr BuildDynamicBroadcastTo(const PatternMap &m, const AnfNodePtr &) { + auto node = m.Get(kMBroadcastTo); MS_EXCEPTION_IF_NULL(node); - auto node_name = common::AnfAlgo::GetCNodeName(node); - if (node_name == prim::kPrimDynamicBroadcastTo->name() && !common::AnfAlgo::IsDynamicShape(node)) { - auto broadcast_to_op_name = prim::kPrimBroadcastTo->name(); - auto ori_cnode = node->cast(); - MS_EXCEPTION_IF_NULL(ori_cnode); - auto input_x = common::AnfAlgo::GetInputNode(ori_cnode, 0); - CNodePtr broadcast_to_node = - opt::NewCNode({NewValueNode(std::make_shared(broadcast_to_op_name)), input_x}, func_graph, {node}); - MS_EXCEPTION_IF_NULL(broadcast_to_node); - broadcast_to_node->set_abstract(node->abstract()); - auto shape_ptr = node->abstract()->BuildShape()->cast(); - MS_EXCEPTION_IF_NULL(shape_ptr); - common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(shape_ptr->shape()), broadcast_to_node); - return broadcast_to_node; + auto broadcast_to_op_name = prim::kPrimBroadcastTo->name(); + auto ori_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(ori_cnode); + auto input_x = common::AnfAlgo::GetInputNode(ori_cnode, 0); + auto func_graph = node->func_graph(); + CNodePtr broadcast_to_node = + opt::NewCNode({NewValueNode(std::make_shared(broadcast_to_op_name)), input_x}, func_graph, {node}); + MS_EXCEPTION_IF_NULL(broadcast_to_node); + broadcast_to_node->set_abstract(node->abstract()); + auto shape_ptr = node->abstract()->BuildShape()->cast(); + MS_EXCEPTION_IF_NULL(shape_ptr); + common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(shape_ptr->shape()), broadcast_to_node); + return broadcast_to_node; +} +} // namespace + +bool ConvertDynamicBroadcastTo::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, + const AnfNodePtr &node) const { + MS_EXCEPTION_IF_NULL(node); + if (!common::AnfAlgo::IsDynamicShape(node)) { + return true; } - return node; + return false; +} + +void ConvertDynamicBroadcastTo::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern).AddVar(kV).AddCNode(kMBroadcastTo, {prim::kPrimDynamicBroadcastTo, kV}); +} + +void ConvertDynamicBroadcastTo::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern).AddCNode(kRBroadcastTo, {prim::kPrimDynamicBroadcastTo, kV}, BuildDynamicBroadcastTo); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.h b/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.h index 05cd2b2f62d..a65cffd4216 100644 --- a/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.h +++ b/mindspore/ccsrc/backend/common/pass/convert_dynamic_broadcast_to.h @@ -16,18 +16,19 @@ #ifndef MINDSPORE_CONVERT_DYNAMIC_BROADCAST_TO_ATTR_H #define MINDSPORE_CONVERT_DYNAMIC_BROADCAST_TO_ATTR_H -#include -#include "ir/anf.h" -#include "include/common/utils/convert_utils.h" -#include "backend/common/optimizer/optimizer.h" + +#include +#include "backend/common/optimizer/pattern_to_pattern.h" + namespace mindspore { namespace opt { -class ConvertDynamicBroadcastTo : public PatternProcessPass { +class ConvertDynamicBroadcastTo : public PatternToPatternPass { public: - explicit ConvertDynamicBroadcastTo(bool multigraph = true) - : PatternProcessPass("convert_dynamic_broadcast_to", multigraph) {} + ConvertDynamicBroadcastTo() : PatternToPatternPass("convert_dynamic_broadcast_to", true) {} ~ConvertDynamicBroadcastTo() override = default; - const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/custom_op_reg_info_to_attr.cc b/mindspore/ccsrc/backend/common/pass/custom_op_reg_info_to_attr.cc index bfe751aeff1..7eac64d1ccc 100644 --- a/mindspore/ccsrc/backend/common/pass/custom_op_reg_info_to_attr.cc +++ b/mindspore/ccsrc/backend/common/pass/custom_op_reg_info_to_attr.cc @@ -27,6 +27,10 @@ namespace mindspore { namespace opt { namespace { +constexpr auto kXs = "Xs"; +constexpr auto kMCustom = "m_custom"; +constexpr auto kRCustom = "r_custom"; + void ParseAttrDefaultValue(const std::string &op_name, const std::string &attr_name, const std::string &attr_value, const std::string &attr_type, const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(prim); @@ -117,30 +121,14 @@ void AddMissingAttrs(const CNodePtr &cnode, kernel::OpImplyType imply_type, cnode->set_input(kAnfPrimitiveIndex, NewValueNode(primitive)); } } -} // namespace -const AnfNodePtr CustomOpRegInfoToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const { - if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) { - return nullptr; - } - auto cnode = node->cast(); +AnfNodePtr BuildCustom(const PatternMap &m, const AnfNodePtr &default_node) { + auto cnode = m.Get(kMCustom)->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (!IsPrimitiveCNode(cnode, prim::kPrimCustom)) { - return nullptr; - } - auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode); MS_EXCEPTION_IF_NULL(primitive); auto func_type = common::AnfAlgo::GetNodeAttr(cnode, kAttrFuncType); - // AKG/AICPU need to process attr, TBE will process later in the json creating phase. - if (!IsOneOfCustomAkgType(func_type) || func_type == kCustomTypeAICPU) { - return nullptr; - } - // Early return if current node does not have attr auto attr_names = primitive->GetAttr(kAttrAttrNames); - if (attr_names == nullptr) { - return nullptr; - } // Early return if all attr in reg info exist in the node's attr std::unordered_set missing_attrs; auto attr_names_vec = GetValue>(attr_names); @@ -156,7 +144,34 @@ const AnfNodePtr CustomOpRegInfoToAttr::Process(const FuncGraphPtr &, const AnfN func_type == kCustomTypeAICPU ? kernel::OpImplyType::kImplyAICPU : kernel::OpImplyType::kImplyAKG; AddMissingAttrs(cnode, imply_type, missing_attrs); - return node; + return cnode; +} +} // namespace + +bool CustomOpRegInfoToAttr::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &node) const { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode); + MS_EXCEPTION_IF_NULL(primitive); + auto func_type = common::AnfAlgo::GetNodeAttr(cnode, kAttrFuncType); + // AKG/AICPU need to process attr, TBE will process later in the json creating phase. + if (!IsOneOfCustomAkgType(func_type) || func_type == kCustomTypeAICPU) { + return false; + } + // Early return if current node does not have attr + auto attr_names = primitive->GetAttr(kAttrAttrNames); + if (attr_names == nullptr) { + return false; + } + return true; +} + +void CustomOpRegInfoToAttr::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern).AddSeqVar(kXs).AddCNode(kMCustom, {prim::kPrimCustom, kXs}); +} + +void CustomOpRegInfoToAttr::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern).AddCNode(kRCustom, {prim::kPrimCustom, kXs}, BuildCustom); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/custom_op_reg_info_to_attr.h b/mindspore/ccsrc/backend/common/pass/custom_op_reg_info_to_attr.h index debd4317aee..ed722a1e8c0 100644 --- a/mindspore/ccsrc/backend/common/pass/custom_op_reg_info_to_attr.h +++ b/mindspore/ccsrc/backend/common/pass/custom_op_reg_info_to_attr.h @@ -15,17 +15,18 @@ */ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CUSTOM_OP_REG_INFO_TO_ATTR_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CUSTOM_OP_REG_INFO_TO_ATTR_H_ -#include "ir/anf.h" -#include "backend/common/optimizer/optimizer.h" +#include +#include "backend/common/optimizer/pattern_to_pattern.h" namespace mindspore { namespace opt { -class CustomOpRegInfoToAttr : public PatternProcessPass { +class CustomOpRegInfoToAttr : public PatternToPatternPass { public: - explicit CustomOpRegInfoToAttr(bool multigraph = true) - : PatternProcessPass("custom_op_reg_info_to_attr", multigraph) {} + CustomOpRegInfoToAttr() : PatternToPatternPass("custom_op_reg_info_to_attr", true) {} ~CustomOpRegInfoToAttr() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override; + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/flatten_concat_fission.cc b/mindspore/ccsrc/backend/common/pass/flatten_concat_fission.cc index a5191d93f12..263c396222d 100644 --- a/mindspore/ccsrc/backend/common/pass/flatten_concat_fission.cc +++ b/mindspore/ccsrc/backend/common/pass/flatten_concat_fission.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "include/common/utils/anfalgo.h" #include "backend/common/session/anf_runtime_algorithm.h" @@ -138,6 +139,12 @@ void ExpandFlattenConcatTupleInput(const FuncGraphPtr &graph, const CNodePtr &cn } } // namespace +std::vector FlattenConcatFission::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimFlattenConcat->name()); + return ret; +} + const BaseRef FlattenConcatFission::DefinePattern() const { VarPtr Xs = std::make_shared(); return VectorRef({prim::kPrimFlattenConcat, Xs}); diff --git a/mindspore/ccsrc/backend/common/pass/flatten_concat_fission.h b/mindspore/ccsrc/backend/common/pass/flatten_concat_fission.h index bf7d808c420..7ffce5190c1 100644 --- a/mindspore/ccsrc/backend/common/pass/flatten_concat_fission.h +++ b/mindspore/ccsrc/backend/common/pass/flatten_concat_fission.h @@ -16,6 +16,8 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FLATTEN_CONCAT_FISSION_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FLATTEN_CONCAT_FISSION_H_ +#include +#include #include "backend/common/optimizer/optimizer.h" namespace mindspore { @@ -26,6 +28,9 @@ class FlattenConcatFission : public PatternProcessPass { ~FlattenConcatFission() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; + + private: + std::vector MustExistPrimitiveName() const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/common/pass/inplace_assign_for_custom_op.cc b/mindspore/ccsrc/backend/common/pass/inplace_assign_for_custom_op.cc index abca22a8800..c71588a362d 100644 --- a/mindspore/ccsrc/backend/common/pass/inplace_assign_for_custom_op.cc +++ b/mindspore/ccsrc/backend/common/pass/inplace_assign_for_custom_op.cc @@ -100,6 +100,12 @@ CNodePtr InplaceAssignAfterTupleGetItem(const FuncGraphPtr &func_graph, const CN return nullptr; } +std::vector InplaceAssignForCustomOp::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimCustom->name()); + return ret; +} + const AnfNodePtr InplaceAssignForCustomOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); diff --git a/mindspore/ccsrc/backend/common/pass/inplace_assign_for_custom_op.h b/mindspore/ccsrc/backend/common/pass/inplace_assign_for_custom_op.h index 73eb777c9eb..333ac27e903 100644 --- a/mindspore/ccsrc/backend/common/pass/inplace_assign_for_custom_op.h +++ b/mindspore/ccsrc/backend/common/pass/inplace_assign_for_custom_op.h @@ -15,6 +15,10 @@ */ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_INPLACE_ASSIGN_FOR_CUSTOM_OP_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_INPLACE_ASSIGN_FOR_CUSTOM_OP_H_ + +#include +#include + #include "ir/anf.h" #include "backend/common/optimizer/optimizer.h" @@ -26,6 +30,7 @@ class InplaceAssignForCustomOp : public PatternProcessPass { : PatternProcessPass("inplace_assign_for_custom_op", multigraph) {} ~InplaceAssignForCustomOp() override = default; const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; + std::vector MustExistPrimitiveName() const override; private: mutable mindspore::HashSet visited_{}; diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/add_attr_for_3d_graph.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/add_attr_for_3d_graph.cc index 17919b6a2cc..2a7210f57e5 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/add_attr_for_3d_graph.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/enhancer/add_attr_for_3d_graph.cc @@ -19,10 +19,10 @@ namespace mindspore { namespace opt { namespace { -constexpr auto m_3d = "m_3d"; -constexpr auto V = "V"; -constexpr auto Xs = "Xs"; -constexpr auto r_3d = "r_3d"; +constexpr auto kM3d = "m_3d"; +constexpr auto kV = "V"; +constexpr auto kXs = "Xs"; +constexpr auto kR3d = "r_3d"; } // namespace bool AddIoFormatAttrFor3DGraph::CheckMatchedDAG(const PatternMap &m, const FuncGraphPtr &graph, @@ -36,7 +36,7 @@ bool AddIoFormatAttrFor3DGraph::CheckMatchedDAG(const PatternMap &m, const FuncG } AnfNodePtr AddAttr(const PatternMap &m, const AnfNodePtr & /* default_cnode */) { - auto node = m.Get(m_3d); + auto node = m.Get(kM3d); common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); auto formats = AnfAlgo::GetAllOutputFormats(node); if (std::any_of(formats.begin(), formats.end(), [](const std::string &format) { return IsOneOf3DFormat(format); })) { @@ -45,10 +45,10 @@ AnfNodePtr AddAttr(const PatternMap &m, const AnfNodePtr & /* default_cnode */) return node; } void AddIoFormatAttrFor3DGraph::DefineSrcPattern(SrcPattern *src_pattern) { - (void)(*src_pattern).AddVar(V, UnVisited).AddSeqVar(Xs).AddCNode(m_3d, {V, Xs}); + (void)(*src_pattern).AddVar(kV, UnVisited).AddSeqVar(kXs).AddCNode(kM3d, {kV, kXs}); } void AddIoFormatAttrFor3DGraph::DefineDstPattern(DstPattern *dst_pattern) { - (void)(*dst_pattern).AddCNode(r_3d, {V, Xs}, AddAttr); + (void)(*dst_pattern).AddCNode(kR3d, {kV, kXs}, AddAttr); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/reselect_call_inline_format.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/reselect_call_inline_format.cc index 1eed9bb954f..a2da26cf8b1 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/reselect_call_inline_format.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/format_type/reselect_call_inline_format.cc @@ -20,7 +20,7 @@ namespace mindspore { namespace opt { namespace { -constexpr auto Xs = "Xs"; +constexpr auto kXs = "Xs"; constexpr auto call_inline = "call_inline"; constexpr auto new_call_inline = "new_call_inline"; } // namespace @@ -41,11 +41,11 @@ AnfNodePtr BuildCallInline(const PatternMap &m, const AnfNodePtr &) { } void ReselectCallInlineFormat::DefineSrcPattern(SrcPattern *src_pattern) { - (*src_pattern).AddSeqVar(Xs).AddCNode(call_inline, {prim::kPrimCallInline, Xs}); + (*src_pattern).AddSeqVar(kXs).AddCNode(call_inline, {prim::kPrimCallInline, kXs}); } void ReselectCallInlineFormat::DefineDstPattern(DstPattern *dst_pattern) { - (*dst_pattern).AddCNode(new_call_inline, {prim::kPrimCallInline, Xs}, BuildCallInline); + (*dst_pattern).AddCNode(new_call_inline, {prim::kPrimCallInline, kXs}, BuildCallInline); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/reduce_axis_update.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/reduce_axis_update.cc index ef3c0b5bd8f..ea80eefac75 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/reduce_axis_update.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ge/reduce_axis_update.cc @@ -29,8 +29,8 @@ constexpr size_t kReduceInputNum = 2; constexpr size_t kAxisInputIndex = 2; constexpr auto r_reduce = "r_reduce"; constexpr auto m_reduce = "m_reduce"; -constexpr auto Xs = "Xs"; -constexpr auto V = "V"; +constexpr auto kXs = "Xs"; +constexpr auto kV = "V"; constexpr auto v_axis = "axis"; } // namespace @@ -128,13 +128,13 @@ AnfNodePtr BuildReduce(const PatternMap &m, const AnfNodePtr &) { } void ReduceAxisUpdate::DefineSrcPattern(SrcPattern *src_pattern) { - (void)(*src_pattern).AddVar(V, IsReduce).AddSeqVar(Xs).AddCNode(m_reduce, {V, Xs}); + (void)(*src_pattern).AddVar(kV, IsReduce).AddSeqVar(kXs).AddCNode(m_reduce, {kV, kXs}); } void ReduceAxisUpdate::DefineDstPattern(DstPattern *dst_pattern) { - auto reduce_input = Unpacking(Xs); + auto reduce_input = Unpacking(kXs); reduce_input[kAxisInputIndex - 1] = v_axis; - (void)(*dst_pattern).AddValueNode(v_axis, BuildAxis).AddCNode(r_reduce, {V, reduce_input}, BuildReduce); + (void)(*dst_pattern).AddValueNode(v_axis, BuildAxis).AddCNode(r_reduce, {kV, reduce_input}, BuildReduce); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/reduce_min_fission.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/reduce_min_fission.cc index ff40dbde58b..f2ebd07b30d 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/reduce_min_fission.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/reduce_min_fission.cc @@ -21,10 +21,10 @@ namespace mindspore { namespace opt { namespace { -constexpr auto m_reduce_min = "m_reduce_min"; -constexpr auto r_reduce_min1 = "r_reduce_min1"; -constexpr auto r_reduce_min2 = "r_reduce_min2"; -constexpr auto X = "X"; +constexpr auto kMReduceMin = "m_reduce_min"; +constexpr auto kRReduceMin1 = "r_reduce_min1"; +constexpr auto kRReduceMin2 = "r_reduce_min2"; +constexpr auto kX1 = "X1"; bool NeedOptimize(const TypeId &dtype, const ShapeVector &shape, const std::vector &axis) { if (dtype != kNumberTypeFloat32) { @@ -129,7 +129,7 @@ bool ReduceMinFission::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &g } AnfNodePtr BuildReduceMin1(const PatternMap &m, const AnfNodePtr &default_node) { - auto cnode = m.Get(m_reduce_min)->cast(); + auto cnode = m.Get(kMReduceMin)->cast(); CNodePtr reduce_min1 = InitReduceMin(default_node->cast(), cnode); auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); auto dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0); @@ -143,7 +143,7 @@ AnfNodePtr BuildReduceMin1(const PatternMap &m, const AnfNodePtr &default_node) } AnfNodePtr BuildReduceMin2(const PatternMap &m, const AnfNodePtr &default_node) { - auto cnode = m.Get(m_reduce_min)->cast(); + auto cnode = m.Get(kMReduceMin)->cast(); CNodePtr reduce_min2 = InitReduceMin(default_node->cast(), cnode); reduce_min2->set_abstract(cnode->abstract()); std::vector axis_last = {-1}; @@ -152,13 +152,13 @@ AnfNodePtr BuildReduceMin2(const PatternMap &m, const AnfNodePtr &default_node) } void ReduceMinFission::DefineSrcPattern(SrcPattern *src_pattern) { - (void)(*src_pattern).AddVar(X).AddCNode(m_reduce_min, {prim::kPrimReduceMinD, X}); + (void)(*src_pattern).AddVar(kX1).AddCNode(kMReduceMin, {prim::kPrimReduceMinD, kX1}); } void ReduceMinFission::DefineDstPattern(DstPattern *dst_pattern) { (void)(*dst_pattern) - .AddCNode(r_reduce_min1, {prim::kPrimReduceMinD, X}, BuildReduceMin1) - .AddCNode(r_reduce_min2, {prim::kPrimReduceMinD, r_reduce_min1}, BuildReduceMin2); + .AddCNode(kRReduceMin1, {prim::kPrimReduceMinD, kX1}, BuildReduceMin1) + .AddCNode(kRReduceMin2, {prim::kPrimReduceMinD, kRReduceMin1}, BuildReduceMin2); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc index 9eb58dc963a..61040111b1a 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.cc @@ -24,12 +24,7 @@ namespace mindspore { namespace opt { -const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &equiv) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(equiv); - +bool AICpuLibSelectPass::Process(const AnfNodePtr &node) const { static const std::set kAICpuOpNames = {kDropoutGenMaskOpName, kEnvironCreateOpName, kEnvironSetOpName, @@ -262,7 +257,7 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An static const std::string kCpuKernelSoName = "mindspore_cpu_kernels"; if (!node->isa()) { - return node; + return false; } auto kernel_name = common::AnfAlgo::GetCNodeName(node); if (kAICpuOpNames.find(kernel_name) != kAICpuOpNames.end()) { @@ -272,7 +267,7 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An common::AnfAlgo::SetNodeAttr(kAttrCustAicpu, MakeValue(kCpuKernelSoName), node); } - return node; + return true; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.h index 67620444ba2..788ebaf9230 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/aicpu_lib_select.h @@ -16,15 +16,14 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AICPU_LIB_SELECT_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AICPU_LIB_SELECT_H_ -#include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/inplace_node_pass.h" namespace mindspore { namespace opt { -class AICpuLibSelectPass : public PatternProcessPass { +class AICpuLibSelectPass : public InplaceNodePass { public: - explicit AICpuLibSelectPass(bool multigraph = true) : PatternProcessPass("env_op_attr_update", multigraph) {} - ~AICpuLibSelectPass() override = default; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + AICpuLibSelectPass() : InplaceNodePass("env_op_attr_update") {} + bool Process(const AnfNodePtr &node) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.cc index c906386be23..527e13e0b0b 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.cc @@ -181,6 +181,12 @@ CNodePtr AllToAllUnifyMindIR::CreateConcatNode(const FuncGraphPtr &graph, const return concat; } +std::vector NeighborExchangeUnifyMindIR::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimNeighborExchange->name()); + return ret; +} + const BaseRef NeighborExchangeUnifyMindIR::DefinePattern() const { return VectorRef({prim::kPrimNeighborExchange, std::make_shared()}); } @@ -193,6 +199,12 @@ const AnfNodePtr NeighborExchangeUnifyMindIR::Process(const FuncGraphPtr &graph, return node; } +std::vector AllToAllUnifyMindIR::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimAllToAll->name()); + return ret; +} + const BaseRef AllToAllUnifyMindIR::DefinePattern() const { return VectorRef({prim::kPrimAllToAll, std::make_shared()}); } diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.h index ec8effd4fb3..d55cf5da491 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/all_to_all_unify_mindir.h @@ -17,6 +17,8 @@ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ALL_TO_ALL_UNIFY_MINDIR_H_ #include +#include +#include #include "backend/common/optimizer/optimizer.h" namespace mindspore { @@ -28,6 +30,9 @@ class NeighborExchangeUnifyMindIR : public PatternProcessPass { ~NeighborExchangeUnifyMindIR() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + std::vector MustExistPrimitiveName() const override; }; class AllToAllUnifyMindIR : public PatternProcessPass { @@ -41,6 +46,7 @@ class AllToAllUnifyMindIR : public PatternProcessPass { CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) const; CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) const; CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &all_to_all_v) const; + std::vector MustExistPrimitiveName() const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/avg_pool_grad_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/avg_pool_grad_unify_mindir.cc index 4b70d14be41..fbd893eb099 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/avg_pool_grad_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/avg_pool_grad_unify_mindir.cc @@ -38,6 +38,15 @@ constexpr size_t kAvgPoolGradInputNum = 3; constexpr size_t kShapeDimNum = 4; constexpr float kKernelMatrixInitNum = 1.0; constexpr size_t kFloat32Len = 4; // size of float32 +constexpr auto kX1 = "X1"; +constexpr auto kX2 = "X2"; +constexpr auto kG = "G"; +constexpr auto kXShapeVNode = "XShapeVNode"; +constexpr auto kMeanMatrixVNode = "MeanMatrixVNode"; +constexpr auto kKernelMatrixVNode = "KernelMatrixVNode"; +constexpr auto kMAvgPoolGrad = "m_avg_pool_grad"; +constexpr auto kRAvgPoolGrad = "r_avg_pool_grad"; + std::vector GetInputXShape(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); return common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0UL); @@ -181,40 +190,63 @@ ValueNodePtr CreateKernelMatrixValueNode(const FuncGraphPtr &func_graph, const A kernel_graph->AddValueNodeToGraph(kernel_matrix_vnode); return kernel_matrix_vnode; } -} // namespace +class BuildXShapeVNode { + public: + BuildXShapeVNode() = default; + AnfNodePtr operator()(const PatternMap &m) const { + auto node = m.Get(kMAvgPoolGrad); + MS_EXCEPTION_IF_NULL(node); + auto avgpool_grad = CheckAnfNodeIfCNodeAndInputSize(node, kAvgPoolGradInputNum); + auto x_shape = GetInputXShape(avgpool_grad); + auto graph = avgpool_grad->func_graph(); + auto x_shape_vnode = CreateShapeValueNode(graph, x_shape); + return x_shape_vnode; + } +}; +class BuildMeanMatrixVNode { + public: + BuildMeanMatrixVNode() = default; + AnfNodePtr operator()(const PatternMap &m) const { + auto node = m.Get(kMAvgPoolGrad); + MS_EXCEPTION_IF_NULL(node); + auto avgpool_grad = CheckAnfNodeIfCNodeAndInputSize(node, kAvgPoolGradInputNum); + auto x_shape = GetInputXShape(avgpool_grad); + auto k_size = common::AnfAlgo::GetNodeAttr>(avgpool_grad, kAttrKernelSize); + auto stride = common::AnfAlgo::GetNodeAttr>(avgpool_grad, kAttrStrides); + auto prim = GetCNodePrimitive(avgpool_grad); + MS_EXCEPTION_IF_NULL(prim); + int64_t pad_mode_value = 0; + CheckAndConvertUtils::GetPadModEnumValue(prim->GetAttr(kAttrPadMode), &pad_mode_value, true); + auto pad_mode = PadMode(pad_mode_value); + auto x_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(avgpool_grad, 0UL); -const BaseRef AvgPoolGradUnifyMindIR::DefinePattern() const { - VarPtr X1 = std::make_shared(); - VarPtr X2 = std::make_shared(); - VarPtr G = std::make_shared(); - VectorRef pattern({prim::kPrimAvgPoolGrad, X1, X2, G}); - return pattern; -} - -const AnfNodePtr AvgPoolGradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); + auto graph = avgpool_grad->func_graph(); + auto mean_matrix_vnode = CreateMeanMatrixValueNode(graph, node, x_shape, k_size, stride, pad_mode, x_dtype); + return mean_matrix_vnode; + } +}; +class BuildKernelMatrixVNode { + public: + BuildKernelMatrixVNode() = default; + AnfNodePtr operator()(const PatternMap &m) const { + auto node = m.Get(kMAvgPoolGrad); + MS_EXCEPTION_IF_NULL(node); + auto avgpool_grad = CheckAnfNodeIfCNodeAndInputSize(node, kAvgPoolGradInputNum); + auto k_size = common::AnfAlgo::GetNodeAttr>(avgpool_grad, kAttrKernelSize); + auto x_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(avgpool_grad, 0UL); + auto x_shape = GetInputXShape(avgpool_grad); + auto graph = avgpool_grad->func_graph(); + auto kernel_matrix_vnode = CreateKernelMatrixValueNode(graph, node, x_shape, k_size, x_dtype); + return kernel_matrix_vnode; + } +}; +AnfNodePtr BuildAvgPoolGrad(const PatternMap &m, const AnfNodePtr &new_node) { + auto node = m.Get(kMAvgPoolGrad); MS_EXCEPTION_IF_NULL(node); auto avgpool_grad = CheckAnfNodeIfCNodeAndInputSize(node, kAvgPoolGradInputNum); - auto x_shape = GetInputXShape(avgpool_grad); - auto x_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(avgpool_grad, 0UL); - auto k_size = common::AnfAlgo::GetNodeAttr>(avgpool_grad, kAttrKernelSize); - auto stride = common::AnfAlgo::GetNodeAttr>(avgpool_grad, kAttrStrides); - auto prim = GetCNodePrimitive(avgpool_grad); - MS_EXCEPTION_IF_NULL(prim); - int64_t pad_mode_value = 0; - CheckAndConvertUtils::GetPadModEnumValue(prim->GetAttr(kAttrPadMode), &pad_mode_value, true); - auto pad_mode = PadMode(pad_mode_value); - - auto x_shape_vnode = CreateShapeValueNode(graph, x_shape); - auto mean_matrix_vnode = CreateMeanMatrixValueNode(graph, node, x_shape, k_size, stride, pad_mode, x_dtype); - auto kernel_matrix_vnode = CreateKernelMatrixValueNode(graph, node, x_shape, k_size, x_dtype); - - std::vector avgpool_grad_vm_inputs = {NewValueNode(std::make_shared(kAvgPoolGradOpName)), - x_shape_vnode, avgpool_grad->input(3UL), mean_matrix_vnode, - kernel_matrix_vnode}; - auto avgpool_grad_vm = NewCNode(avgpool_grad_vm_inputs, graph); + MS_EXCEPTION_IF_NULL(new_node); + auto avgpool_grad_vm = new_node->cast(); MS_EXCEPTION_IF_NULL(avgpool_grad_vm); avgpool_grad_vm->set_scope(avgpool_grad->scope()); avgpool_grad_vm->set_abstract(avgpool_grad->abstract()); @@ -228,5 +260,20 @@ const AnfNodePtr AvgPoolGradUnifyMindIR::Process(const FuncGraphPtr &graph, cons common::AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), avgpool_grad_vm); return avgpool_grad_vm; } +} // namespace + +void AvgPoolGradUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern).AddVar(kX1).AddVar(kX2).AddVar(kG).AddCNode(kMAvgPoolGrad, {prim::kPrimAvgPoolGrad, kX1, kX2, kG}); +} + +void AvgPoolGradUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern) + .AddValueNode(kXShapeVNode, BuildXShapeVNode()) + .AddValueNode(kMeanMatrixVNode, BuildMeanMatrixVNode()) + .AddValueNode(kKernelMatrixVNode, BuildKernelMatrixVNode()) + .AddCNode(kRAvgPoolGrad, + {std::make_shared(kAvgPoolGradOpName), kXShapeVNode, kG, kMeanMatrixVNode, kKernelMatrixVNode}, + BuildAvgPoolGrad); +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/avg_pool_grad_unify_mindir.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/avg_pool_grad_unify_mindir.h index e6f376a26ae..6f126f95560 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/avg_pool_grad_unify_mindir.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/avg_pool_grad_unify_mindir.h @@ -16,18 +16,18 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AVG_POOL_GRAD_UNIFY_MINDIR_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AVG_POOL_GRAD_UNIFY_MINDIR_H_ -#include #include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/pattern_to_pattern.h" namespace mindspore { namespace opt { -class AvgPoolGradUnifyMindIR : public PatternProcessPass { +class AvgPoolGradUnifyMindIR : public PatternToPatternPass { public: - explicit AvgPoolGradUnifyMindIR(bool multigraph = true) - : PatternProcessPass("avg_pool_grad_unify_mindir", multigraph) {} + AvgPoolGradUnifyMindIR() : PatternToPatternPass("avg_pool_grad_unify_mindir", true) {} ~AvgPoolGradUnifyMindIR() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/bn_grad_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/bn_grad_unify_mindir.cc index 75e5c52e042..d990a69af9f 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/bn_grad_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/bn_grad_unify_mindir.cc @@ -23,22 +23,25 @@ namespace mindspore { namespace opt { namespace { constexpr auto kAttrUnifyIRPassed = "unifyir_passed"; +constexpr auto kX1 = "X1"; +constexpr auto kX2 = "X2"; +constexpr auto kX3 = "X3"; +constexpr auto kX4 = "X4"; +constexpr auto kX5 = "X5"; +constexpr auto kXs = "Xs"; +constexpr auto kMBatchnormGrad = "m_batchnorm_grad"; +constexpr auto kRBatchnormGrad = "r_batchnorm_grad"; } // namespace -AnfNodePtr BatchNormGradUnifyMindIR::CreateNewBatchNormGrad(const FuncGraphPtr &graph, - const CNodePtr &bn_grad_node) const { - MS_EXCEPTION_IF_NULL(graph); +AnfNodePtr BuildBatchNormGrad(const PatternMap &m, const AnfNodePtr &new_node) { + auto node = m.Get(kMBatchnormGrad); + MS_EXCEPTION_IF_NULL(node); + auto bn_grad_node = node->cast(); MS_EXCEPTION_IF_NULL(bn_grad_node); size_t kBNGradInputNum = 6; - const auto &bn_grad_node_inputs = bn_grad_node->inputs(); CheckCNodeInputSize(bn_grad_node, kBNGradInputNum); - std::vector bn_grad_inputs = {NewValueNode(std::make_shared(kBatchNormGradOpName)), - bn_grad_node_inputs[kDim1], - bn_grad_node_inputs[kDim2], - bn_grad_node_inputs[kDim3], - bn_grad_node_inputs[kDim4], - bn_grad_node_inputs[kDim5]}; - auto new_bn_grad = NewCNode(bn_grad_inputs, graph); + auto new_bn_grad = new_node->cast(); + MS_EXCEPTION_IF_NULL(new_bn_grad); MS_EXCEPTION_IF_NULL(new_bn_grad); new_bn_grad->set_scope(bn_grad_node->scope()); auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 0UL), @@ -57,24 +60,33 @@ AnfNodePtr BatchNormGradUnifyMindIR::CreateNewBatchNormGrad(const FuncGraphPtr & return new_bn_grad; } -const BaseRef BatchNormGradUnifyMindIR::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto prim = std::make_shared(kBatchNormGradOpName); - return VectorRef({prim, Xs}); -} - -const AnfNodePtr BatchNormGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { +bool BatchNormGradUnifyMindIR::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &func_graph, + const AnfNodePtr &node) const { MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(func_graph); - auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (common::AnfAlgo::HasNodeAttr(kAttrUnifyIRPassed, cnode) || (func_graph->has_flag(kAttrMutableKernel) && !GetBoolAttr(cnode, kAttrIsTraining))) { - return nullptr; + return false; } - return CreateNewBatchNormGrad(func_graph, cnode); + return true; +} + +void BatchNormGradUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern) + .AddVar(kX1) + .AddVar(kX2) + .AddVar(kX3) + .AddVar(kX4) + .AddVar(kX5) + .AddSeqVar(kXs) + .AddCNode(kMBatchnormGrad, {std::make_shared(kBatchNormGradOpName), kX1, kX2, kX3, kX4, kX5, kXs}); +} + +void BatchNormGradUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern) + .AddCNode(kRBatchnormGrad, {std::make_shared(kBatchNormGradOpName), kX1, kX2, kX3, kX4, kX5}, + BuildBatchNormGrad); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/bn_grad_unify_mindir.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/bn_grad_unify_mindir.h index d036eb80ce7..21265b78ad0 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/bn_grad_unify_mindir.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/bn_grad_unify_mindir.h @@ -17,19 +17,18 @@ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_BN_GRAD_UNIFY_MINDIR_H_ #include "backend/common/optimizer/optimizer.h" -#include "backend/common/optimizer/helper.h" +#include "backend/common/optimizer/pattern_to_pattern.h" namespace mindspore { namespace opt { -class BatchNormGradUnifyMindIR : public PatternProcessPass { +class BatchNormGradUnifyMindIR : public PatternToPatternPass { public: - explicit BatchNormGradUnifyMindIR(bool multigraph = true) : PatternProcessPass("bn_grad_unify_mindir", multigraph) {} + BatchNormGradUnifyMindIR() : PatternToPatternPass("bn_grad_unify_mindir", true) {} ~BatchNormGradUnifyMindIR() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; - private: - AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node) const; + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.cc index ffd5fd3b1ac..09d7c8a297a 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.cc @@ -46,6 +46,11 @@ constexpr int64_t kV3ShapeLimitSize = 1 << 30; constexpr size_t kDropoutGradInputTensorNum = 2; constexpr size_t kFloat16Len = 2; // size of float16 constexpr size_t kInt64Len = 8; // size of int64 +constexpr auto kX1 = "X1"; +constexpr auto kX2 = "X2"; +constexpr auto kKeepProbValue = "KeepProbValue"; +constexpr auto kMDropoutGrad = "m_dropout_grad"; +constexpr auto kRDropoutDoMask = "r_dropout_do_mask"; TypeId GetInputXDataType(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); @@ -303,6 +308,92 @@ void UpdateReturnNode(const FuncGraphPtr &graph, const AnfNodePtr &origin_node, g_output->set_abstract(abstract); graph->set_output(g_output); } + +class BuildKeepProbValue { + public: + BuildKeepProbValue() = default; + AnfNodePtr operator()(const PatternMap &m) const { + auto node = m.Get(kMDropoutGrad); + MS_EXCEPTION_IF_NULL(node); + auto dropout_grad_cnode = node->cast(); + CheckCNodeInputSize(dropout_grad_cnode, kDropoutGradInputTensorNum); + + auto func_graph = node->func_graph(); + auto grad_input_type_id = GetInputXDataType(dropout_grad_cnode); + auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id); + return keep_prob_value; + } +}; + +AnfNodePtr BuildDropoutDoMask(const PatternMap &m, const AnfNodePtr &) { + auto node = m.Get(kMDropoutGrad); + MS_EXCEPTION_IF_NULL(node); + auto dropout_grad_cnode = node->cast(); + CheckCNodeInputSize(dropout_grad_cnode, kDropoutGradInputTensorNum); + + auto func_graph = dropout_grad_cnode->func_graph(); + auto grad_input_type_id = GetInputXDataType(dropout_grad_cnode); + auto grad_input_shape = GetDropoutInputShape(dropout_grad_cnode->input(kIndex1)); + auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id); + auto use_v3 = WhetherUseDropoutV3(dropout_grad_cnode, grad_input_shape); + + // DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter + // in that scene, need to be updated. + auto mask_input = dropout_grad_cnode->input(kIndex2); + MS_EXCEPTION_IF_NULL(mask_input); + if (mask_input->isa()) { + // update abstract + auto mask_abstract = mask_input->abstract(); + MS_EXCEPTION_IF_NULL(mask_abstract); + auto grad_shape_vec = grad_input_shape->shape(); + auto mask_shape = + use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec); + mask_abstract = std::make_shared(kUInt8, mask_shape); + mask_input->set_abstract(mask_abstract); + // update kernel info + auto kernel_build_info_builder = std::make_shared(); + kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); + kernel_build_info_builder->SetOutputsDeviceType(std::vector{kNumberTypeUInt8}); + kernel_build_info_builder->SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR}); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get()); + } else if (IsPrimitiveCNode(mask_input, prim::kPrimTupleGetItem)) { + auto mask_input_cnode = mask_input->cast(); + MS_EXCEPTION_IF_NULL(mask_input_cnode); + auto tuple_input = mask_input_cnode->input(kIndex1); + MS_EXCEPTION_IF_NULL(tuple_input); + if (IsValueNode(tuple_input)) { + auto tuple_abstract = tuple_input->abstract(); + MS_EXCEPTION_IF_NULL(tuple_abstract); + abstract::AbstractSequencePtr sequence_abstract_ptr = tuple_abstract->cast(); + MS_EXCEPTION_IF_NULL(sequence_abstract_ptr); + // Dropout's outputs only have two elements. + if (sequence_abstract_ptr->size() != kIndex2) { + MS_LOG(EXCEPTION) << "Dropout's outputs have more than two elements, " << sequence_abstract_ptr->size(); + } + abstract::AbstractBasePtrList abs{}; + abs.push_back(sequence_abstract_ptr->elements()[0]); + // modify mask abstract + auto mask_abstract = mask_input->abstract(); + MS_EXCEPTION_IF_NULL(mask_abstract); + auto grad_shape_vec = grad_input_shape->shape(); + auto mask_shape = + use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec); + mask_abstract = std::make_shared(kUInt8, mask_shape); + mask_input->set_abstract(mask_abstract); + abs.push_back(mask_abstract); + auto new_abstract = std::make_shared(abs); + tuple_input->set_abstract(new_abstract); + } + } + + // CreateDropoutDoMask + auto do_mask_abstract = + std::make_shared(TypeIdToType(grad_input_type_id), grad_input_shape); + auto dropout_do_mask = CreateDropoutDoMaskCNode(func_graph, dropout_grad_cnode, + {dropout_grad_cnode->input(kIndex1), mask_input, keep_prob_value}, + do_mask_abstract, use_v3); + return dropout_do_mask; +} } // namespace const BaseRef DropoutAndDropoutGradUnifyMindIR::DefinePattern() const { @@ -433,6 +524,12 @@ const AnfNodePtr DropoutUnifyMindIR0::Process(const FuncGraphPtr &func_graph, co return tuple_cnode; } +std::vector DropoutUnifyMindIR1::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimDropout->name()); + return ret; +} + const BaseRef DropoutUnifyMindIR1::DefinePattern() const { VarPtr X = std::make_shared(); return VectorRef({prim::kPrimDropout, X}); @@ -477,80 +574,17 @@ const AnfNodePtr DropoutUnifyMindIR1::Process(const FuncGraphPtr &func_graph, co return make_tuple; } -const BaseRef DropoutGradUnifyMindIR::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Y = std::make_shared(); - auto dropout_grad_prim = std::make_shared(kDropoutGradOpName); - return VectorRef({dropout_grad_prim, X, Y}); +void DropoutGradUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern) + .AddVar(kX1) + .AddVar(kX2) + .AddCNode(kMDropoutGrad, {std::make_shared(kDropoutGradOpName), kX1, kX2}); } -const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto dropout_grad_cnode = node->cast(); - CheckCNodeInputSize(dropout_grad_cnode, kDropoutGradInputTensorNum); - - auto grad_input_type_id = GetInputXDataType(dropout_grad_cnode); - auto grad_input_shape = GetDropoutInputShape(dropout_grad_cnode->input(kIndex1)); - auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id); - auto use_v3 = WhetherUseDropoutV3(dropout_grad_cnode, grad_input_shape); - - // DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter - // in that scene, need to be updated. - auto mask_input = dropout_grad_cnode->input(kIndex2); - MS_EXCEPTION_IF_NULL(mask_input); - if (mask_input->isa()) { - // update abstract - auto mask_abstract = mask_input->abstract(); - MS_EXCEPTION_IF_NULL(mask_abstract); - auto grad_shape_vec = grad_input_shape->shape(); - auto mask_shape = - use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec); - mask_abstract = std::make_shared(kUInt8, mask_shape); - mask_input->set_abstract(mask_abstract); - // update kernel info - auto kernel_build_info_builder = std::make_shared(); - kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); - kernel_build_info_builder->SetOutputsDeviceType(std::vector{kNumberTypeUInt8}); - kernel_build_info_builder->SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR}); - AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get()); - } else if (IsPrimitiveCNode(mask_input, prim::kPrimTupleGetItem)) { - auto mask_input_cnode = mask_input->cast(); - MS_EXCEPTION_IF_NULL(mask_input_cnode); - auto tuple_input = mask_input_cnode->input(kIndex1); - MS_EXCEPTION_IF_NULL(tuple_input); - if (IsValueNode(tuple_input)) { - auto tuple_abstract = tuple_input->abstract(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - abstract::AbstractSequencePtr sequence_abstract_ptr = tuple_abstract->cast(); - MS_EXCEPTION_IF_NULL(sequence_abstract_ptr); - // Dropout's outputs only have two elements. - if (sequence_abstract_ptr->size() != kIndex2) { - MS_LOG(EXCEPTION) << "Dropout's outputs have more than two elements, " << sequence_abstract_ptr->size(); - } - abstract::AbstractBasePtrList abs{}; - abs.push_back(sequence_abstract_ptr->elements()[0]); - // modify mask abstract - auto mask_abstract = mask_input->abstract(); - MS_EXCEPTION_IF_NULL(mask_abstract); - auto grad_shape_vec = grad_input_shape->shape(); - auto mask_shape = - use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec); - mask_abstract = std::make_shared(kUInt8, mask_shape); - mask_input->set_abstract(mask_abstract); - abs.push_back(mask_abstract); - auto new_abstract = std::make_shared(abs); - tuple_input->set_abstract(new_abstract); - } - } - - // CreateDropoutDoMask - auto do_mask_abstract = - std::make_shared(TypeIdToType(grad_input_type_id), grad_input_shape); - auto dropout_do_mask = CreateDropoutDoMaskCNode(func_graph, dropout_grad_cnode, - {dropout_grad_cnode->input(kIndex1), mask_input, keep_prob_value}, - do_mask_abstract, use_v3); - return dropout_do_mask; +void DropoutGradUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern) + .AddValueNode(kKeepProbValue, BuildKeepProbValue()) + .AddCNode(kRDropoutDoMask, {std::make_shared(kDropoutDoMaskOpName), kX1, kX2, kKeepProbValue}, + BuildDropoutDoMask); } } // namespace mindspore::opt diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.h index f4881ebb31c..91ff65b308d 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/dropout_unify_mindir.h @@ -17,7 +17,10 @@ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_ #include +#include +#include #include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/pattern_to_pattern.h" namespace mindspore { namespace opt { @@ -49,15 +52,18 @@ class DropoutUnifyMindIR1 : public PatternProcessPass { ~DropoutUnifyMindIR1() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + std::vector MustExistPrimitiveName() const override; }; -class DropoutGradUnifyMindIR : public PatternProcessPass { +class DropoutGradUnifyMindIR : public PatternToPatternPass { public: - explicit DropoutGradUnifyMindIR(bool multigraph = true) - : PatternProcessPass("dropoutgrad_unify_mindir", multigraph) {} + DropoutGradUnifyMindIR() : PatternToPatternPass("dropoutgrad_unify_mindir", true) {} ~DropoutGradUnifyMindIR() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/fse_decode_adjust.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/fse_decode_adjust.cc index aa24266a849..27865498009 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/fse_decode_adjust.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/fse_decode_adjust.cc @@ -17,6 +17,7 @@ #include #include +#include #include "include/common/utils/utils.h" #include "utils/ms_context.h" @@ -27,6 +28,12 @@ namespace mindspore { namespace opt { +std::vector FSEDecodeAdjust::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(std::make_shared(kFSEDecodeOpName)->name()); + return ret; +} + const BaseRef FSEDecodeAdjust::DefinePattern() const { VarPtr Xs = std::make_shared(); auto prim = std::make_shared(kFSEDecodeOpName); diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/fse_decode_adjust.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/fse_decode_adjust.h index a67a526231c..f8170939034 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/fse_decode_adjust.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/fse_decode_adjust.h @@ -17,6 +17,7 @@ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FSE_DECODE_ADJUST_H_ #include +#include #include "backend/common/optimizer/optimizer.h" #include "backend/common/optimizer/helper.h" @@ -28,6 +29,9 @@ class FSEDecodeAdjust : public PatternProcessPass { ~FSEDecodeAdjust() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + std::vector MustExistPrimitiveName() const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_to_maxpool_with_argmax.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_to_maxpool_with_argmax.cc index 3eb00d2d12b..e0c1b7fdacc 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_to_maxpool_with_argmax.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_to_maxpool_with_argmax.cc @@ -18,6 +18,7 @@ #include #include +#include #include "include/common/utils/utils.h" #include "utils/ms_context.h" @@ -120,6 +121,13 @@ void MaxPool2MaxPoolWithArgmax::SetNodeAttrs(const CNodePtr &maxpool, const CNod common::AnfAlgo::SetNodeAttr(kAttrKernelSize, MakeValue(ksize), maxpool_grad_argmax); } +std::vector MaxPool2MaxPoolWithArgmax::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimMaxPool->name()); + ret.emplace_back(prim::kPrimMaxPoolGrad->name()); + return ret; +} + const BaseRef MaxPool2MaxPoolWithArgmax::DefinePattern() const { VarPtr X = std::make_shared(); VarPtr Y = std::make_shared(); diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_to_maxpool_with_argmax.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_to_maxpool_with_argmax.h index 0c9b94067ee..3a95c9a0d49 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_to_maxpool_with_argmax.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_to_maxpool_with_argmax.h @@ -18,6 +18,7 @@ #include #include +#include #include "backend/common/optimizer/optimizer.h" namespace mindspore { @@ -36,6 +37,7 @@ class MaxPool2MaxPoolWithArgmax : public PatternProcessPass { const std::vector &maxpool_argmax_outputs) const; void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const CNodePtr &maxpool_argmax, const CNodePtr &maxpool_grad_argmax) const; + std::vector MustExistPrimitiveName() const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_with_argmax_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_with_argmax_unify_mindir.cc index abbde1b4d10..760e694ea0d 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_with_argmax_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_with_argmax_unify_mindir.cc @@ -32,6 +32,15 @@ constexpr size_t kMaxPoolGradWithArgmaxInputTensorNum = 3; constexpr size_t kMaxPoolGradWithArgmaxInputNum = 4; constexpr size_t kMaxPoolWithArgmaxShape = 4; constexpr size_t kAlignBytes = 16; +constexpr auto kX1 = "X1"; +constexpr auto kX2 = "X2"; +constexpr auto kMaxPoolIndex = "index0"; +constexpr auto kMMaxPool = "m_max_pool"; +constexpr auto kRMaxPool = "r_max_pool"; +constexpr auto kMMaxpoolWithArgmax = "m_maxpool_with_argmax"; +constexpr auto kMTupleGetitem0 = "m_tuple_getitem0"; +constexpr auto kMMaxpoolGradWithArgmax = "m_maxpool_grad_with_argmax"; +constexpr auto kRMaxpoolGradWithArgmax = "r_maxpool_grad_with_argmax"; bool IsC(const BaseRef &n) { if (utils::isa(n)) { @@ -48,17 +57,9 @@ CNodePtr GetMaxPoolWithArgmax(const CNodePtr &maxpool_grad_with_argmax) { MS_EXCEPTION_IF_NULL(tuple_getitem0_anf); return tuple_getitem0_anf->cast(); } -} // namespace -const BaseRef MaxPoolWithArgmaxUnifyMindIR::DefinePattern() const { - VarPtr X = std::make_shared(); - VectorRef pattern({prim::kPrimMaxPoolWithArgmax, X}); - return pattern; -} - -const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); +AnfNodePtr BuildMaxPoolWithArgmax(const PatternMap &m, const AnfNodePtr &) { + auto node = m.Get(kMMaxPool); MS_EXCEPTION_IF_NULL(node); auto maxpool_with_argmax = node->cast(); MS_EXCEPTION_IF_NULL(maxpool_with_argmax); @@ -85,19 +86,8 @@ const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph return maxpool_with_argmax; } -const BaseRef MaxPoolGradWithArgmaxUnifyMindIR::DefinePattern() const { - VarPtr X = std::make_shared(); - VarPtr Y = std::make_shared(); - VarPtr index0 = std::make_shared(IsC); - VectorRef maxpool_with_argmax({prim::kPrimMaxPoolWithArgmax, X}); - VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, maxpool_with_argmax, index0}); - VectorRef maxpool_grad_with_argmax({prim::kPrimMaxPoolGradWithArgmax, X, Y, tuple_getitem0}); - return maxpool_grad_with_argmax; -} - -const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); +AnfNodePtr BuildMaxPoolGradWithArgmax(const PatternMap &m, const AnfNodePtr &) { + auto node = m.Get(kMMaxpoolGradWithArgmax); MS_EXCEPTION_IF_NULL(node); auto maxpool_grad_with_argmax = node->cast(); MS_EXCEPTION_IF_NULL(maxpool_grad_with_argmax); @@ -122,5 +112,29 @@ const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &g return maxpool_grad_with_argmax; } +} // namespace + +void MaxPoolWithArgmaxUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern).AddVar(kX1).AddCNode(kMMaxPool, {prim::kPrimMaxPoolWithArgmax, kX1}); +} + +void MaxPoolWithArgmaxUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern).AddCNode(kRMaxPool, {prim::kPrimMaxPoolWithArgmax, kX1}, BuildMaxPoolWithArgmax); +} + +void MaxPoolGradWithArgmaxUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern) + .AddVar(kX1) + .AddVar(kX2) + .AddVar(kMaxPoolIndex, IsC) + .AddCNode(kMMaxpoolWithArgmax, {prim::kPrimMaxPoolWithArgmax, kX1}) + .AddCNode(kMTupleGetitem0, {prim::kPrimTupleGetItem, kMMaxpoolWithArgmax, kMaxPoolIndex}) + .AddCNode(kMMaxpoolGradWithArgmax, {prim::kPrimMaxPoolGradWithArgmax, kX1, kX2, kMTupleGetitem0}); +} +void MaxPoolGradWithArgmaxUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern) + .AddCNode(kRMaxpoolGradWithArgmax, {prim::kPrimMaxPoolGradWithArgmax, kX1, kX2, kMTupleGetitem0}, + BuildMaxPoolGradWithArgmax); +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_with_argmax_unify_mindir.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_with_argmax_unify_mindir.h index 0ce8e97f711..9ab08089f27 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_with_argmax_unify_mindir.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/maxpool_with_argmax_unify_mindir.h @@ -16,27 +16,27 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_ -#include #include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/pattern_to_pattern.h" namespace mindspore { namespace opt { -class MaxPoolWithArgmaxUnifyMindIR : public PatternProcessPass { +class MaxPoolWithArgmaxUnifyMindIR : public PatternToPatternPass { public: - explicit MaxPoolWithArgmaxUnifyMindIR(bool multigraph = true) - : PatternProcessPass("maxpool_with_argmax_unify_mindir", multigraph) {} + MaxPoolWithArgmaxUnifyMindIR() : PatternToPatternPass("maxpool_with_argmax_unify_mindir", true) {} ~MaxPoolWithArgmaxUnifyMindIR() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; }; -class MaxPoolGradWithArgmaxUnifyMindIR : public PatternProcessPass { +class MaxPoolGradWithArgmaxUnifyMindIR : public PatternToPatternPass { public: - explicit MaxPoolGradWithArgmaxUnifyMindIR(bool multigraph = true) - : PatternProcessPass("maxpool_grad_with_argmax_unify_mindir", multigraph) {} + MaxPoolGradWithArgmaxUnifyMindIR() : PatternToPatternPass("maxpool_grad_with_argmax_unify_mindir", true) {} ~MaxPoolGradWithArgmaxUnifyMindIR() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.cc index 670488afc6e..71bb16490f1 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.cc @@ -912,6 +912,12 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraph return addn; } +std::vector NeighborExchangeV2UnifyMindIR::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimNeighborExchangeV2->name()); + return ret; +} + const BaseRef NeighborExchangeV2UnifyMindIR::DefinePattern() const { return VectorRef({prim::kPrimNeighborExchangeV2, std::make_shared()}); } @@ -929,9 +935,16 @@ const AnfNodePtr NeighborExchangeV2UnifyMindIR::Process(const FuncGraphPtr &grap return concat; } +std::vector NeighborExchangeV2GradUnifyMindIR::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimNeighborExchangeV2Grad->name()); + return ret; +} + const BaseRef NeighborExchangeV2GradUnifyMindIR::DefinePattern() const { return VectorRef({prim::kPrimNeighborExchangeV2Grad, std::make_shared()}); } + const AnfNodePtr NeighborExchangeV2GradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { MS_EXCEPTION_IF_NULL(graph); diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.h index befc6751ecd..66942162b37 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/neighbor_exchange_v2_unify_mindir.h @@ -19,6 +19,8 @@ #include #include #include +#include + #include "backend/common/optimizer/optimizer.h" #include "backend/common/session/anf_runtime_algorithm.h" #include "include/common/utils/anfalgo.h" @@ -49,6 +51,7 @@ class NeighborExchangeV2UnifyMindIR : public PatternProcessPass { const CNodePtr &all_to_all_v) const; CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2, const CNodePtr &all_to_all_v) const; + std::vector MustExistPrimitiveName() const override; }; class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass { @@ -68,6 +71,7 @@ class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass { CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad, const CNodePtr &all_to_all_v, const std::vector &split_nodes, const std::vector &split_num) const; + std::vector MustExistPrimitiveName() const override; }; } // namespace opt diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/optimizer_unify_output.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/optimizer_unify_output.cc index 9b7613f2576..9d612bcf2c7 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/optimizer_unify_output.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/optimizer_unify_output.cc @@ -30,12 +30,27 @@ constexpr size_t kFtrlOutputNum = 3; constexpr size_t kMomentumOutputNum = 2; constexpr size_t kRMSPropOutputNum = 3; constexpr size_t kCenteredRMSPropOutputNum = 4; +constexpr auto kOptVar = "var"; +constexpr auto kOptAccum = "accum"; +constexpr auto kOptLinear = "linear"; +constexpr auto kOptGrad = "grad"; +constexpr auto kOptLr = "lr"; +constexpr auto kOptL1 = "l1"; +constexpr auto kOptL2 = "l2"; +constexpr auto kOptLrPower = "lr_power"; +constexpr auto kOptU = "u"; +constexpr auto kOptIndex = "index"; +constexpr auto kMomentum = "momentum"; +constexpr auto kInputs = "inputs"; +constexpr auto kMg = "mg"; +constexpr auto kMs = "ms"; +constexpr auto kMom = "mom"; +constexpr auto kRho = "rho"; +constexpr auto kEpsilon = "epsilon"; +constexpr auto kMOptimizer = "m_optimizer"; +constexpr auto kRTupleGet = "r_tuple_get"; -CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t output_size, - const PatternProcessPass &pass) { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - +bool CheckNode(const AnfNodePtr &node) { auto cnode_ptr = node->cast(); MS_EXCEPTION_IF_NULL(cnode_ptr); @@ -43,89 +58,122 @@ CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const MS_EXCEPTION_IF_NULL(abstract); if (common::AnfAlgo::HasNodeAttr("optim_output_passed", cnode_ptr) && abstract->isa()) { - return nullptr; + return false; } + return true; +} + +AnfNodePtr BuildZero(const PatternMap &) { return NewValueNode(static_cast(0)); } +} // namespace + +AnfNodePtr BuildTupleGetFunc::operator()(const PatternMap &m, const AnfNodePtr &get_item) const { + auto node = m.Get(kMOptimizer); + MS_EXCEPTION_IF_NULL(node); + + auto cnode_ptr = node->cast(); + MS_EXCEPTION_IF_NULL(cnode_ptr); + + auto abstract = cnode_ptr->abstract(); + MS_EXCEPTION_IF_NULL(abstract); common::AnfAlgo::SetNodeAttr("optim_output_passed", MakeValue(true), cnode_ptr); std::vector abstract_list; - for (size_t i = 0; i < output_size; i++) { + for (size_t i = 0; i < output_size_; i++) { abstract_list.push_back(abstract->Clone()); } auto abstract_tuple = std::make_shared(abstract_list); cnode_ptr->set_abstract(abstract_tuple); - auto index = NewValueNode(static_cast(0)); - auto get_item = pass.NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_ptr, index}, graph); - MS_EXCEPTION_IF_NULL(get_item); - get_item->set_abstract(abstract->Clone()); return get_item; } -} // namespace -const BaseRef FtrlUnifyOutput::DefinePattern() const { - VarPtr var = std::make_shared(); - VarPtr accum = std::make_shared(); - VarPtr linear = std::make_shared(); - VarPtr grad = std::make_shared(); - VarPtr lr = std::make_shared(); - VarPtr l1 = std::make_shared(); - VarPtr l2 = std::make_shared(); - VarPtr lr_power = std::make_shared(); - VarPtr u = std::make_shared(); - VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power, u}); - return pattern; +bool FtrlUnifyOutput::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &node) const { + return CheckNode(node); } -const AnfNodePtr FtrlUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const { - return ProcessOutput(graph, node, kFtrlOutputNum, *this); +void FtrlUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern) + .AddVar(kOptVar) + .AddVar(kOptAccum) + .AddVar(kOptLinear) + .AddVar(kOptGrad) + .AddVar(kOptLr) + .AddVar(kOptL1) + .AddVar(kOptL2) + .AddVar(kOptLrPower) + .AddVar(kOptU) + .AddCNode(kMOptimizer, {prim::kPrimApplyFtrl, kOptVar, kOptAccum, kOptLinear, kOptGrad, kOptLr, kOptL1, kOptL2, + kOptLrPower, kOptU}); } -const BaseRef MomentumUnifyOutput::DefinePattern() const { - VarPtr var = std::make_shared(); - VarPtr accum = std::make_shared(); - VarPtr lr = std::make_shared(); - VarPtr grad = std::make_shared(); - VarPtr momentum = std::make_shared(); - VarPtr u = std::make_shared(); - VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum, u}); - return pattern; +void FtrlUnifyOutput::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern) + .AddValueNode(kOptIndex, BuildZero) + .AddCNode(kRTupleGet, {prim::kPrimTupleGetItem, kMOptimizer, kOptIndex}, BuildTupleGetFunc(kFtrlOutputNum)); } -const AnfNodePtr MomentumUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - return ProcessOutput(graph, node, kMomentumOutputNum, *this); +bool MomentumUnifyOutput::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &node) const { + return CheckNode(node); } -const BaseRef RMSPropUnifyOutput::DefinePattern() const { - VarPtr inputs = std::make_shared(); - VectorRef pattern({prim::kPrimApplyRMSProp, inputs}); - return pattern; +void MomentumUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern) + .AddVar(kOptVar) + .AddVar(kOptAccum) + .AddVar(kOptLr) + .AddVar(kOptGrad) + .AddVar(kMomentum) + .AddVar(kOptU) + .AddCNode(kMOptimizer, {prim::kPrimApplyMomentum, kOptVar, kOptAccum, kOptLr, kOptGrad, kMomentum, kOptU}); } -const AnfNodePtr RMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - return ProcessOutput(graph, node, kRMSPropOutputNum, *this); +void MomentumUnifyOutput::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern) + .AddValueNode(kOptIndex, BuildZero) + .AddCNode(kRTupleGet, {prim::kPrimTupleGetItem, kMOptimizer, kOptIndex}, BuildTupleGetFunc(kMomentumOutputNum)); } -const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const { - VarPtr var = std::make_shared(); - VarPtr mg = std::make_shared(); - VarPtr ms = std::make_shared(); - VarPtr mom = std::make_shared(); - VarPtr grad = std::make_shared(); - VarPtr lr = std::make_shared(); - VarPtr rho = std::make_shared(); - VarPtr momentum = std::make_shared(); - VarPtr epsilon = std::make_shared(); - VarPtr u = std::make_shared(); - VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon, u}); - return pattern; +bool RMSPropUnifyOutput::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &node) const { + return CheckNode(node); } -const AnfNodePtr CenteredRMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - return ProcessOutput(graph, node, kCenteredRMSPropOutputNum, *this); +void RMSPropUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern).AddSeqVar(kInputs).AddCNode(kMOptimizer, {prim::kPrimApplyRMSProp, kInputs}); +} + +void RMSPropUnifyOutput::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern) + .AddValueNode(kOptIndex, BuildZero) + .AddCNode(kRTupleGet, {prim::kPrimTupleGetItem, kMOptimizer, kOptIndex}, BuildTupleGetFunc(kRMSPropOutputNum)); +} + +bool CenteredRMSPropUnifyOutput::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, + const AnfNodePtr &node) const { + return CheckNode(node); +} + +void CenteredRMSPropUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern) + .AddVar(kOptVar) + .AddVar(kMg) + .AddVar(kMs) + .AddVar(kMom) + .AddVar(kOptGrad) + .AddVar(kOptLr) + .AddVar(kRho) + .AddVar(kMomentum) + .AddVar(kEpsilon) + .AddVar(kOptU) + .AddCNode(kMOptimizer, {prim::kPrimApplyCenteredRMSProp, kOptVar, kMg, kMs, kMom, kOptGrad, kOptLr, kRho, kMomentum, + kEpsilon, kOptU}); +} + +void CenteredRMSPropUnifyOutput::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern) + .AddValueNode(kOptIndex, BuildZero) + .AddCNode(kRTupleGet, {prim::kPrimTupleGetItem, kMOptimizer, kOptIndex}, + BuildTupleGetFunc(kCenteredRMSPropOutputNum)); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/optimizer_unify_output.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/optimizer_unify_output.h index b129a47e89c..ec5eec011e2 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/optimizer_unify_output.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/optimizer_unify_output.h @@ -16,42 +16,52 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_ -#include #include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/pattern_to_pattern.h" namespace mindspore { namespace opt { -class FtrlUnifyOutput : public PatternProcessPass { +class BuildTupleGetFunc { public: - explicit FtrlUnifyOutput(bool multigraph = true) : PatternProcessPass("ftrl_unify_output", multigraph) {} + explicit BuildTupleGetFunc(const size_t output_size) : output_size_(output_size) {} + AnfNodePtr operator()(const PatternMap &m, const AnfNodePtr &get_item) const; + size_t output_size_; +}; +class FtrlUnifyOutput : public PatternToPatternPass { + public: + FtrlUnifyOutput() : PatternToPatternPass("ftrl_unify_output", true) {} ~FtrlUnifyOutput() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -class MomentumUnifyOutput : public PatternProcessPass { + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; +}; +class MomentumUnifyOutput : public PatternToPatternPass { public: - explicit MomentumUnifyOutput(bool multigraph = true) : PatternProcessPass("momentum_unify_output", multigraph) {} + MomentumUnifyOutput() : PatternToPatternPass("momentum_unify_output", true) {} ~MomentumUnifyOutput() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -class CenteredRMSPropUnifyOutput : public PatternProcessPass { + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; +}; +class CenteredRMSPropUnifyOutput : public PatternToPatternPass { public: - explicit CenteredRMSPropUnifyOutput(bool multigraph = true) - : PatternProcessPass("centered_rmsprop_unify_output", multigraph) {} + CenteredRMSPropUnifyOutput() : PatternToPatternPass("centered_rmsprop_unify_output", true) {} ~CenteredRMSPropUnifyOutput() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -class RMSPropUnifyOutput : public PatternProcessPass { + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; +}; +class RMSPropUnifyOutput : public PatternToPatternPass { public: - explicit RMSPropUnifyOutput(bool multigraph = true) : PatternProcessPass("rmsprop_unify_output", multigraph) {} + RMSPropUnifyOutput() : PatternToPatternPass("rmsprop_unify_output", true) {} ~RMSPropUnifyOutput() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/quant_dtype_cast_adjust.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/quant_dtype_cast_adjust.cc index f6d2ece223b..ce9c59faf86 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/quant_dtype_cast_adjust.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/quant_dtype_cast_adjust.cc @@ -28,6 +28,12 @@ namespace mindspore { namespace opt { +std::vector QuantDTypeCastAdjust::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(std::make_shared(kQuantDTypeCastOpName)->name()); + return ret; +} + const BaseRef QuantDTypeCastAdjust::DefinePattern() const { VarPtr Xs = std::make_shared(); auto prim = std::make_shared(kQuantDTypeCastOpName); diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/quant_dtype_cast_adjust.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/quant_dtype_cast_adjust.h index 229bf732360..4de2764f94e 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/quant_dtype_cast_adjust.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/quant_dtype_cast_adjust.h @@ -17,6 +17,7 @@ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_QUANT_DTYPE_CAST_ADJUST_H_ #include +#include #include "backend/common/optimizer/optimizer.h" #include "backend/common/optimizer/helper.h" @@ -28,6 +29,9 @@ class QuantDTypeCastAdjust : public PatternProcessPass { ~QuantDTypeCastAdjust() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + std::vector MustExistPrimitiveName() const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/slice_grad_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/slice_grad_unify_mindir.cc index 0bbeb096352..69daf04ebdd 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/slice_grad_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/slice_grad_unify_mindir.cc @@ -34,6 +34,10 @@ namespace opt { namespace { constexpr size_t kSliceGradInputTensorNum = 4; constexpr size_t kSliceGradCangjieInputTensorNum = 2; +constexpr auto kMSliceGrad = "m_slice_grad"; +constexpr auto kRPad = "r_pad"; +constexpr auto kX1 = "X1"; +constexpr auto kXs = "Xs"; std::vector GetInputXShape(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); @@ -47,19 +51,10 @@ std::vector GetTupleValue(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(value_node->value()); return GetValue>(value_node->value()); } -} // namespace -const BaseRef SliceGradUnifyMindIR::DefinePattern() const { - VarPtr Xs = std::make_shared(); - VectorRef slice_grad({std::make_shared("SliceGrad"), Xs}); - return slice_grad; -} - -const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); +AnfNodePtr BuildPad(const PatternMap &m, const AnfNodePtr &pad) { + auto node = m.Get(kMSliceGrad); MS_EXCEPTION_IF_NULL(node); - auto slice_grad = node->cast(); MS_EXCEPTION_IF_NULL(slice_grad); auto input_num = common::AnfAlgo::GetInputTensorNum(slice_grad); @@ -68,9 +63,6 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const << "] of node " + slice_grad->DebugString() + " is not equal to " << kSliceGradInputTensorNum << " or " << kSliceGradCangjieInputTensorNum << trace::DumpSourceLines(node); } - std::vector pad_inputs = {NewValueNode(std::make_shared(kPadDOpName)), - slice_grad->input(kIndex1)}; - auto pad = NewCNode(pad_inputs, graph); MS_EXCEPTION_IF_NULL(pad); pad->set_scope(slice_grad->scope()); pad->set_abstract(slice_grad->abstract()); @@ -80,12 +72,6 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const std::vector begins; std::vector sizes; if (input_num == kSliceGradInputTensorNum) { - auto begin_value = GetValueNode(slice_grad->input(kIndex3)); - auto size_value = GetValueNode(slice_grad->input(kIndex4)); - if (IsDynamic(x_shape) || begin_value == nullptr || size_value == nullptr || !begin_value->isa() || - !size_value->isa()) { - return nullptr; - } begins = GetTupleValue(slice_grad->input(kIndex3)); sizes = GetTupleValue(slice_grad->input(kIndex4)); } else { @@ -108,5 +94,31 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const return pad; } +} // namespace + +bool SliceGradUnifyMindIR::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &node) const { + MS_EXCEPTION_IF_NULL(node); + auto slice_grad = node->cast(); + MS_EXCEPTION_IF_NULL(slice_grad); + auto input_num = common::AnfAlgo::GetInputTensorNum(slice_grad); + auto x_shape = GetInputXShape(slice_grad); + if (input_num == kSliceGradInputTensorNum) { + auto begin_value = GetValueNode(slice_grad->input(kIndex3)); + auto size_value = GetValueNode(slice_grad->input(kIndex4)); + if (IsDynamic(x_shape) || begin_value == nullptr || size_value == nullptr || !begin_value->isa() || + !size_value->isa()) { + return false; + } + } + return true; +} + +void SliceGradUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern).AddVar(kX1).AddSeqVar(kXs).AddCNode(kMSliceGrad, {std::make_shared("SliceGrad"), kX1, kXs}); +} + +void SliceGradUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern).AddCNode(kRPad, {std::make_shared(kPadDOpName), kX1}, BuildPad); +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/slice_grad_unify_mindir.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/slice_grad_unify_mindir.h index 89e82c951f2..87f034b331d 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/slice_grad_unify_mindir.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/slice_grad_unify_mindir.h @@ -16,17 +16,19 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SLICE_GRAD_UNIFY_MINDIR_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SLICE_GRAD_UNIFY_MINDIR_H_ -#include #include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/pattern_to_pattern.h" namespace mindspore { namespace opt { -class SliceGradUnifyMindIR : public PatternProcessPass { +class SliceGradUnifyMindIR : public PatternToPatternPass { public: - explicit SliceGradUnifyMindIR(bool multigraph = true) : PatternProcessPass("slice_grad_unify_mindir", multigraph) {} + SliceGradUnifyMindIR() : PatternToPatternPass("slice_grad_unify_mindir", true) {} ~SliceGradUnifyMindIR() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/space_batch_nd_attr_update.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/space_batch_nd_attr_update.cc index e19bab98415..e0d4d7e4096 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/space_batch_nd_attr_update.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/space_batch_nd_attr_update.cc @@ -31,19 +31,14 @@ constexpr size_t kBlockShapeDimNum = 2; constexpr auto kAttrBlockShape = "block_shape"; constexpr auto kAttrPaddings = "paddings"; constexpr auto kAttrCrops = "crops"; -} // namespace - -const BaseRef SpaceToBatchNDAttrUpdate::DefinePattern() const { - VarPtr X = std::make_shared(); - VectorRef pattern({prim::kPrimSpaceToBatchND, X}); - return pattern; -} - -const AnfNodePtr SpaceToBatchNDAttrUpdate::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); +constexpr auto kV = "V"; +constexpr auto kMSpace = "m_space"; +constexpr auto kRSpace = "r_space"; +constexpr auto kMBatch = "m_batch"; +constexpr auto kRBatch = "r_batch"; +AnfNodePtr BuildSpace(const PatternMap &m, const AnfNodePtr &default_node) { + auto node = m.Get(kMSpace); auto block_shape = common::AnfAlgo::GetNodeAttr>(node, kAttrBlockShape); if (block_shape.size() == kBlockShapeDimNum) { (void)block_shape.insert(block_shape.cbegin(), 1); @@ -57,17 +52,8 @@ const AnfNodePtr SpaceToBatchNDAttrUpdate::Process(const FuncGraphPtr &graph, co return node; } -const BaseRef BatchToSpaceNDAttrUpdate::DefinePattern() const { - VarPtr X = std::make_shared(); - VectorRef pattern({prim::kPrimBatchToSpaceND, X}); - return pattern; -} - -const AnfNodePtr BatchToSpaceNDAttrUpdate::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(node); - +AnfNodePtr BuildBatch(const PatternMap &m, const AnfNodePtr &default_node) { + auto node = m.Get(kMBatch); auto block_shape = common::AnfAlgo::GetNodeAttr>(node, kAttrBlockShape); if (block_shape.size() == kBlockShapeDimNum) { (void)block_shape.insert(block_shape.cbegin(), 1); @@ -80,5 +66,30 @@ const AnfNodePtr BatchToSpaceNDAttrUpdate::Process(const FuncGraphPtr &graph, co } return node; } +} // namespace + +bool SpaceToBatchNDAttrUpdate::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const { + return true; +} + +void SpaceToBatchNDAttrUpdate::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern).AddVar(kV).AddCNode(kMSpace, {prim::kPrimSpaceToBatchND, kV}); +} + +void SpaceToBatchNDAttrUpdate::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern).AddCNode(kRSpace, {prim::kPrimSpaceToBatchND, kV}, BuildSpace); +} + +bool BatchToSpaceNDAttrUpdate::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const { + return true; +} + +void BatchToSpaceNDAttrUpdate::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern).AddVar(kV).AddCNode(kMBatch, {prim::kPrimBatchToSpaceND, kV}); +} + +void BatchToSpaceNDAttrUpdate::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern).AddCNode(kRBatch, {prim::kPrimBatchToSpaceND, kV}, BuildBatch); +} } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/space_batch_nd_attr_update.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/space_batch_nd_attr_update.h index 074878fb06f..e9b2826b2ef 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/space_batch_nd_attr_update.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/space_batch_nd_attr_update.h @@ -17,26 +17,26 @@ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SPACE_BATCH_ND_ATTR_UPDATE_H_ #include -#include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/pattern_to_pattern.h" namespace mindspore { namespace opt { -class SpaceToBatchNDAttrUpdate : public PatternProcessPass { +class SpaceToBatchNDAttrUpdate : public PatternToPatternPass { public: - explicit SpaceToBatchNDAttrUpdate(bool multigraph = true) - : PatternProcessPass("space_to_batch_nd_attr_update", multigraph) {} + SpaceToBatchNDAttrUpdate() : PatternToPatternPass("space_to_batch_nd_attr_update", true) {} ~SpaceToBatchNDAttrUpdate() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; }; -class BatchToSpaceNDAttrUpdate : public PatternProcessPass { +class BatchToSpaceNDAttrUpdate : public PatternToPatternPass { public: - explicit BatchToSpaceNDAttrUpdate(bool multigraph = true) - : PatternProcessPass("batch_to_space_nd_attr_update", multigraph) {} + BatchToSpaceNDAttrUpdate() : PatternToPatternPass("batch_to_space_nd_attr_update", true) {} ~BatchToSpaceNDAttrUpdate() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc index 6b3cacfb4f0..1b06e071eeb 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc @@ -459,6 +459,12 @@ CNodePtr CreateMulInput(const FuncGraphPtr &graph, const CNodePtr &mul_node, con } } // namespace +std::vector SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimSparseSoftmaxCrossEntropyWithLogits->name()); + return ret; +} + const BaseRef SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern() const { VarPtr x1 = std::make_shared(); VarPtr x2 = std::make_shared(); @@ -619,6 +625,13 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process } } +std::vector PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimSparseSoftmaxCrossEntropyWithLogits->name()); + ret.emplace_back(prim::kPrimMul->name()); + return ret; +} + const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern() const { VarPtr x1 = std::make_shared(); VarPtr x2 = std::make_shared(); @@ -654,6 +667,14 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro return new_mul_node; } +std::vector PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::MustExistPrimitiveName() const { + std::vector ret; + ret.emplace_back(prim::kPrimSparseSoftmaxCrossEntropyWithLogits->name()); + ret.emplace_back(prim::kPrimCast->name()); + ret.emplace_back(prim::kPrimMul->name()); + return ret; +} + const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const { VarPtr x1 = std::make_shared(); VarPtr x2 = std::make_shared(); diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h index edfaa721a13..1a3a0d76cce 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.h @@ -19,6 +19,7 @@ #include #include +#include #include "backend/common/optimizer/optimizer.h" namespace mindspore { @@ -31,6 +32,9 @@ class SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass ~SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; + + private: + std::vector MustExistPrimitiveName() const override; }; class GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass { @@ -67,6 +71,9 @@ class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public Patter ~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; + + private: + std::vector MustExistPrimitiveName() const override; }; class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public PatternProcessPass { @@ -76,6 +83,9 @@ class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public Patt ~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; + + private: + std::vector MustExistPrimitiveName() const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/update_input_names_strided_slice_grad.cc b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/update_input_names_strided_slice_grad.cc index 86b6a2e9a0e..eda301e088b 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/update_input_names_strided_slice_grad.cc +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/update_input_names_strided_slice_grad.cc @@ -24,30 +24,45 @@ namespace mindspore { namespace opt { -const BaseRef StridedSliceGradUpdateInputNames::DefinePattern() const { - VarPtr Xs = std::make_shared(); - auto strided_slice_grad_prim = std::make_shared(kStridedSliceGradOpName); - return VectorRef({strided_slice_grad_prim, Xs}); -} +namespace { +constexpr auto kXs = "Xs"; +constexpr auto kMSliceGrad = "m_slice_grad"; +constexpr auto kRSliceGrad = "r_slice_grad"; -const AnfNodePtr StridedSliceGradUpdateInputNames::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(graph); +AnfNodePtr BuildSliceGrad(const PatternMap &m, const AnfNodePtr &) { + auto node = m.Get(kMSliceGrad); MS_EXCEPTION_IF_NULL(node); auto strided_slice_grad = node->cast(); MS_EXCEPTION_IF_NULL(strided_slice_grad); const size_t shapex_index = 1; + auto primitive = common::AnfAlgo::GetCNodePrimitive(strided_slice_grad); + MS_EXCEPTION_IF_NULL(primitive); + auto input_names_ptr = primitive->GetAttr(kAttrInputNames); + MS_EXCEPTION_IF_NULL(input_names_ptr); + auto input_names_vec = GetValue>(input_names_ptr); + input_names_vec[shapex_index] = "shape"; + common::AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_vec), strided_slice_grad); + return strided_slice_grad; +} +} // namespace + +bool StridedSliceGradUpdateInputNames::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, + const AnfNodePtr &node) const { + auto strided_slice_grad = node->cast(); + MS_EXCEPTION_IF_NULL(strided_slice_grad); if (common::AnfAlgo::IsDynamicShape(strided_slice_grad)) { - auto primitive = common::AnfAlgo::GetCNodePrimitive(strided_slice_grad); - MS_EXCEPTION_IF_NULL(primitive); - auto input_names_ptr = primitive->GetAttr(kAttrInputNames); - MS_EXCEPTION_IF_NULL(input_names_ptr); - auto input_names_vec = GetValue>(input_names_ptr); - input_names_vec[shapex_index] = "shape"; - common::AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_vec), strided_slice_grad); + return true; } - return nullptr; + return false; +} + +void StridedSliceGradUpdateInputNames::DefineSrcPattern(SrcPattern *src_pattern) { + (*src_pattern).AddSeqVar(kXs).AddCNode(kMSliceGrad, {std::make_shared(kStridedSliceGradOpName), kXs}); +} + +void StridedSliceGradUpdateInputNames::DefineDstPattern(DstPattern *dst_pattern) { + (*dst_pattern).AddCNode(kRSliceGrad, {std::make_shared(kStridedSliceGradOpName), kXs}, BuildSliceGrad); } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/update_input_names_strided_slice_grad.h b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/update_input_names_strided_slice_grad.h index 2e78329a999..a1f2052ba74 100644 --- a/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/update_input_names_strided_slice_grad.h +++ b/mindspore/ccsrc/plugin/device/ascend/optimizer/mindir/update_input_names_strided_slice_grad.h @@ -16,18 +16,19 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_UPDATE_INPUT_NAMES_STRIDED_SLICE_GRAD_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_UPDATE_INPUT_NAMES_STRIDED_SLICE_GRAD_H_ -#include #include "backend/common/optimizer/optimizer.h" +#include "backend/common/optimizer/pattern_to_pattern.h" namespace mindspore { namespace opt { -class StridedSliceGradUpdateInputNames : public PatternProcessPass { +class StridedSliceGradUpdateInputNames : public PatternToPatternPass { public: - explicit StridedSliceGradUpdateInputNames(bool multigraph = true) - : PatternProcessPass("update_strided_slice_grad_input_names", multigraph) {} + StridedSliceGradUpdateInputNames() : PatternToPatternPass("update_strided_slice_grad_input_names", true) {} ~StridedSliceGradUpdateInputNames() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override; + + void DefineSrcPattern(SrcPattern *src_pattern) override; + void DefineDstPattern(DstPattern *dst_pattern) override; + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/core/ir/manager.cc b/mindspore/core/ir/manager.cc index 230b637a6ce..73efcd81aae 100644 --- a/mindspore/core/ir/manager.cc +++ b/mindspore/core/ir/manager.cc @@ -211,6 +211,7 @@ FuncGraphManager::FuncGraphManager(const std::vector &roots, bool void FuncGraphManager::Reset() { func_graphs_ = FuncGraphSet(); + func_graphs_index_ = FuncGraphIndexMap(); all_nodes_ = AnfNodeSet(); node_users_ = NodeUsersMap(); signals_ = std::make_shared(); @@ -285,6 +286,14 @@ FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) c return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; } +const FuncGraphIndexPtr &FuncGraphManager::func_graph_index(const FuncGraphPtr &fg) const { + auto iter = func_graphs_index_.find(fg); + if (iter == func_graphs_index_.end()) { + MS_LOG(EXCEPTION) << "Func graph: " << fg->ToString() << " is not add FuncGraphIndexMap."; + } + return func_graphs_index_.at(fg); +} + bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(recursive_); @@ -346,6 +355,8 @@ void FuncGraphManager::AddFuncGraph(const FuncGraphPtr &func_graph, bool is_root (void)new_nodes.emplace_back(std::move(return_node)); } + func_graphs_index_.emplace(func_graph, std::make_shared()); + // Acquire all nodes from func_graph. AcquireNodes(std::move(new_nodes)); } @@ -362,6 +373,7 @@ void FuncGraphManager::Clear() noexcept { } func_graphs_.clear(); + func_graphs_index_.clear(); all_nodes_.clear(); node_users_.clear(); roots_.clear(); diff --git a/mindspore/core/ir/manager.h b/mindspore/core/ir/manager.h index 406dd1c26d6..b8da704a9ad 100644 --- a/mindspore/core/ir/manager.h +++ b/mindspore/core/ir/manager.h @@ -53,10 +53,13 @@ using ChangePtr = std::unique_ptr; class FuncGraphTransaction; class FuncGraphManager; +class FuncGraphPassIndex; using FuncGraphManagerPtr = std::shared_ptr; +using FuncGraphIndexPtr = std::shared_ptr; using AnfNodeIndexSet = CompactSet>; using NodeUsersMap = mindspore::HashMap>; +using FuncGraphIndexMap = mindspore::HashMap; using FuncGraphSetPair = std::pair; using FuncGraphSetPtr = std::shared_ptr; @@ -80,6 +83,21 @@ using CNodeIndexPair = std::pair; using CNodeIndexPairPtr = std::shared_ptr; using FuncGraphToFuncGraphSetMap = OrderedMap; +// For Fast Pass +class FuncGraphPassIndex { + public: + FuncGraphPassIndex() : has_gen_index_(false) {} + void set_has_gen_index(bool is_gen_index) { has_gen_index_ = is_gen_index; } + bool has_gen_index() const { return has_gen_index_; } + mindspore::HashMap node_to_fg_; + mindspore::HashMap> name_to_cnode_; + mindspore::HashMap> subgraph_out_caller_map_; + mindspore::HashMap node_degree_; + + private: + bool has_gen_index_; +}; + // analysis base class, graphs analysis which need dynamic compute by DepCollector in each read class DepComputer { public: @@ -331,6 +349,8 @@ class MS_CORE_API FuncGraphManager : public std::enable_shared_from_this> recursive_graphs(const FuncGraphPtr &fg) const; @@ -359,8 +379,9 @@ class MS_CORE_API FuncGraphManager : public std::enable_shared_from_this signals_; diff --git a/mindspore/core/ir/value.cc b/mindspore/core/ir/value.cc index d4dd723a2ec..ad05da61779 100644 --- a/mindspore/core/ir/value.cc +++ b/mindspore/core/ir/value.cc @@ -132,7 +132,7 @@ bool FP32Imm::operator==(const Value &other) const { } } bool FP32Imm::operator==(const FP32Imm &other) const { - if (std::isinf(v_) && std::isinf(other.v_)) { + if ((std::isinf(v_) && std::isinf(other.v_)) || (std::isnan(v_) && std::isnan(other.v_))) { return true; } return fabs(v_ - other.v_) < DBL_EPSILON; @@ -186,7 +186,7 @@ std::string ValueSequence::DumpText() const { } bool FP64Imm::operator==(const FP64Imm &other) const { - if (std::isinf(v_) && std::isinf(other.v_)) { + if ((std::isinf(v_) && std::isinf(other.v_)) || (std::isnan(v_) && std::isnan(other.v_))) { return true; } return fabs(v_ - other.v_) < DBL_EPSILON; diff --git a/tests/ut/cpp/pre_activate/common/fast_pattern_to_pattern_pass_test.cc b/tests/ut/cpp/pre_activate/common/fast_pattern_to_pattern_pass_test.cc new file mode 100644 index 00000000000..a1f9d635f05 --- /dev/null +++ b/tests/ut/cpp/pre_activate/common/fast_pattern_to_pattern_pass_test.cc @@ -0,0 +1,617 @@ +/** + * Copyright 2023 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pattern_to_pattern_pass_utils.h" +#include "backend/common/optimizer/node_pass.h" + +namespace mindspore { +namespace opt { +namespace { +const auto kZero = 0; +const auto kOne = 1; +const auto kTwo = 2; +const auto kThree = 3; + +const auto kA = "a"; +const auto kB = "b"; +const auto kC = "c"; +const auto kD = "d"; +const auto kE = "e"; +const auto kAAddB = "a_add_b"; +const auto kCAddD = "c_add_d"; +const auto kMul = "mul"; +const auto kAdd = "add"; + +class TestFastMul0 : public PatternToPatternPass { + // a*b + a*c -> a*(b+c) + public: + explicit TestFastMul0() : PatternToPatternPass("test_fast_mul0") {} + ~TestFastMul0() override = default; + + void DefineSrcPattern(SrcPattern *src_pattern) override { + (*src_pattern) + .AddVar("a") + .AddVar("b") + .AddVar("c") + .AddCNode("ab", {std::make_shared(kMulOpName), "a", "b"}) + .AddCNode("ac", {std::make_shared(kMulOpName), "a", "c"}) + .AddCNode("add", {std::make_shared(kAddOpName), "ab", "ac"}); + } + void DefineDstPattern(DstPattern *dst_pattern) override { + (*dst_pattern) + .AddCNode("bc", {std::make_shared(kAddOpName), "b", "c"}) + .AddCNode("mul", {std::make_shared(kMulOpName), "a", "bc"}); + } + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; } +}; + +class TestFastMul1 : public PatternToPatternPass { + // a*b + c*d -> a*c + public: + explicit TestFastMul1() : PatternToPatternPass("test_fast_mul1") {} + ~TestFastMul1() override = default; + + void DefineSrcPattern(SrcPattern *src_pattern) override { + (*src_pattern) + .AddVar("a") + .AddVar("b") + .AddVar("c") + .AddVar("d") + .AddCNode("ab", {std::make_shared(kMulOpName), "a", "b"}) + .AddCNode("cd", {std::make_shared(kMulOpName), "c", "d"}) + .AddCNode("add", {std::make_shared(kAddOpName), "ab", "cd"}); + } + void DefineDstPattern(DstPattern *dst_pattern) override { + (*dst_pattern).AddCNode("ad", {std::make_shared(kMulOpName), "a", "d"}); + } + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; } +}; + +class TestFastMul2 : public PatternToPatternPass { + // a*b -> b*a + public: + explicit TestFastMul2() : PatternToPatternPass("test_fast_mul2") {} + ~TestFastMul2() override = default; + + void DefineSrcPattern(SrcPattern *src_pattern) override { + (*src_pattern).AddSeqVar("Sv").AddCNode("ab", {std::make_shared(kMulOpName), "Sv"}); + } + void DefineDstPattern(DstPattern *dst_pattern) override { + auto ba = Unpacking("Sv"); + auto ab = Unpacking("Sv"); + ba[0] = ab[1]; + ba[1] = ab[0]; + (*dst_pattern).AddCNode("mul", {std::make_shared(kMulOpName), ba}); + } + bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; } +}; +} // namespace + +class TestFastPatternToPatternPass : public UT::Common { + public: + TestFastPatternToPatternPass() : fg_(std::make_shared()){}; + + public: + FuncGraphPtr fg_; +}; + +/// Feature: Fast PatternToPattern Pass +/// Description: Fast PatternToPattern Pass rewrite graph +/// Expectation: Get correct Graph +TEST_F(TestFastPatternToPatternPass, Mul0) { + // a*b + a*c -> a*(b+c) + // init + auto check = CheckPattern(); + auto pass = TestFastMul0(); + + // build func graph + auto a = std::make_shared(fg_); + auto b = std::make_shared(fg_); + auto c = std::make_shared(fg_); + AnfNodePtr ab = + std::make_shared(std::vector{NewValueNode(std::make_shared(kMulOpName)), a, b}, fg_); + AnfNodePtr ac = + std::make_shared(std::vector{NewValueNode(std::make_shared(kMulOpName)), a, c}, fg_); + AnfNodePtr add = std::make_shared( + std::vector{NewValueNode(std::make_shared(kAddOpName)), ab, ac}, fg_); + + fg_->set_output(add); + auto manager = MakeManager({fg_}); + if (manager) { + manager->AddFuncGraph(fg_); + fg_->set_manager(manager); + } + auto func_graph_index = manager->func_graph_index(fg_); + GenIndex(fg_, func_graph_index); + + ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 2); + + ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end()); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end()); + + auto &add_set = func_graph_index->name_to_cnode_[kAddOpName]; + auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName]; + + ASSERT_TRUE(add_set.size() == 1); + ASSERT_TRUE(mul_set.size() == 2); + ASSERT_TRUE(add_set.find(add) != add_set.end()); + ASSERT_TRUE(mul_set.find(ab) != mul_set.end()); + ASSERT_TRUE(mul_set.find(ac) != mul_set.end()); + + auto new_node = pass.Run(fg_, add); + ASSERT_NE(new_node, nullptr); + (void)manager->Replace(add, new_node); + pass.AfterProcess(add, new_node, fg_, func_graph_index); + + ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("bc")) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("mul")) == 1); + + ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end()); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end()); + + auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName]; + auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName]; + + ASSERT_TRUE(add_set_2.size() == 1); + ASSERT_TRUE(mul_set_2.size() == 1); + ASSERT_TRUE(add_set_2.find(pass.m_->Get("bc")) != add_set_2.end()); + ASSERT_TRUE(mul_set_2.find(pass.m_->Get("mul")) != mul_set_2.end()); + + // build pattern + check.src_pattern_.AddVar("a") + .AddVar("b") + .AddVar("c") + .AddCNode("bc", {std::make_shared(kAddOpName), "b", "c"}) + .AddCNode("mul", {std::make_shared(kMulOpName), "a", "bc"}); + + // pattern engine + ASSERT_TRUE(check.build_pattern_map(new_node)); + + // check + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c)); + ASSERT_EQ(check.m_->Get("bc")->cast()->inputs().size(), 3); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(0), + NewValueNode(std::make_shared(kAddOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(1), b)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(2), c)); + ASSERT_EQ(check.m_->Get("mul")->cast()->inputs().size(), 3); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(0), + NewValueNode(std::make_shared(kMulOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(1), a)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(2), check.m_->Get("bc"))); +} + +/// Feature: Fast PatternToPattern Pass +/// Description: Fast PatternToPattern Pass rewrite graph +/// Expectation: Get correct Graph +TEST_F(TestFastPatternToPatternPass, Mul0NotRoot) { + // (a*b + a*c) + d -> a*(b+c) + d + // init + auto check = CheckPattern(); + auto pass = TestFastMul0(); + + // build func graph + auto a = std::make_shared(fg_); + auto b = std::make_shared(fg_); + auto c = std::make_shared(fg_); + auto d = std::make_shared(fg_); + AnfNodePtr ab = + std::make_shared(std::vector{NewValueNode(std::make_shared(kMulOpName)), a, b}, fg_); + AnfNodePtr ac = + std::make_shared(std::vector{NewValueNode(std::make_shared(kMulOpName)), a, c}, fg_); + AnfNodePtr add = std::make_shared( + std::vector{NewValueNode(std::make_shared(kAddOpName)), ab, ac}, fg_); + AnfNodePtr add1 = std::make_shared( + std::vector{NewValueNode(std::make_shared(kAddOpName)), add, d}, fg_); + + fg_->set_output(add1); + auto manager = MakeManager({fg_}); + if (manager) { + manager->AddFuncGraph(fg_); + fg_->set_manager(manager); + } + auto func_graph_index = manager->func_graph_index(fg_); + GenIndex(fg_, func_graph_index); + + ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 2); + + ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end()); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end()); + + auto &add_set = func_graph_index->name_to_cnode_[kAddOpName]; + auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName]; + + ASSERT_TRUE(add_set.size() == 2); + ASSERT_TRUE(mul_set.size() == 2); + ASSERT_TRUE(add_set.find(add1) != add_set.end()); + ASSERT_TRUE(add_set.find(add) != add_set.end()); + ASSERT_TRUE(mul_set.find(ab) != mul_set.end()); + ASSERT_TRUE(mul_set.find(ac) != mul_set.end()); + + auto new_node = pass.Run(fg_, add); + ASSERT_NE(new_node, nullptr); + (void)manager->Replace(add, new_node); + pass.AfterProcess(add, new_node, fg_, func_graph_index); + + ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("bc")) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("mul")) == 1); + + ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end()); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end()); + + auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName]; + auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName]; + + ASSERT_TRUE(add_set_2.size() == 2); + ASSERT_TRUE(mul_set_2.size() == 1); + ASSERT_TRUE(add_set_2.find(add1) != add_set_2.end()); + ASSERT_TRUE(add_set_2.find(pass.m_->Get("bc")) != add_set_2.end()); + ASSERT_TRUE(mul_set_2.find(pass.m_->Get("mul")) != mul_set_2.end()); + + // build pattern + check.src_pattern_.AddVar("a") + .AddVar("b") + .AddVar("c") + .AddVar("d") + .AddCNode("bc", {std::make_shared(kAddOpName), "b", "c"}) + .AddCNode("mul", {std::make_shared(kMulOpName), "a", "bc"}) + .AddCNode("add1", {std::make_shared(kAddOpName), "mul", "d"}); + + // pattern engine + ASSERT_TRUE(check.build_pattern_map(add1)); + + // check + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("d"), d)); + + ASSERT_EQ(check.m_->Get("bc")->cast()->inputs().size(), 3); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(0), + NewValueNode(std::make_shared(kAddOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(1), b)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast()->input(2), c)); + + ASSERT_EQ(check.m_->Get("mul")->cast()->inputs().size(), 3); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(0), + NewValueNode(std::make_shared(kMulOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(1), a)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast()->input(2), check.m_->Get("bc"))); + + ASSERT_EQ(check.m_->Get("add1")->cast()->inputs().size(), 3); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(0), + NewValueNode(std::make_shared(kAddOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(1), check.m_->Get("mul"))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(2), d)); +} + +/// Feature: Fast PatternToPattern Pass +/// Description: Fast PatternToPattern Pass rewrite graph +/// Expectation: Get correct Graph +TEST_F(TestFastPatternToPatternPass, Mul1) { + // (a * (b1 + d) + (c1 * c2) * d) + e -> (a + d) + e + // init + auto check = CheckPattern(); + auto pass = TestFastMul1(); + + // build func graph + auto a = std::make_shared(fg_); + auto b = std::make_shared(fg_); + auto c1 = std::make_shared(fg_); + auto c2 = std::make_shared(fg_); + auto d = std::make_shared(fg_); + auto e = std::make_shared(fg_); + + AnfNodePtr b_add_d = + std::make_shared(std::vector{NewValueNode(std::make_shared(kAddOpName)), b, d}, fg_); + AnfNodePtr c1_mul_c2 = std::make_shared( + std::vector{NewValueNode(std::make_shared(kMulOpName)), c1, c2}, fg_); + AnfNodePtr a_mul = std::make_shared( + std::vector{NewValueNode(std::make_shared(kMulOpName)), a, b_add_d}, fg_); + AnfNodePtr d_mul = std::make_shared( + std::vector{NewValueNode(std::make_shared(kMulOpName)), c1_mul_c2, d}, fg_); + AnfNodePtr add = std::make_shared( + std::vector{NewValueNode(std::make_shared(kAddOpName)), a_mul, d_mul}, fg_); + AnfNodePtr add1 = std::make_shared( + std::vector{NewValueNode(std::make_shared(kAddOpName)), add, e}, fg_); + + fg_->set_output(add1); + auto manager = MakeManager({fg_}); + if (manager) { + manager->AddFuncGraph(fg_); + fg_->set_manager(manager); + } + auto func_graph_index = manager->func_graph_index(fg_); + GenIndex(fg_, func_graph_index); + + ASSERT_TRUE(func_graph_index->node_degree_.at(b_add_d) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(c1_mul_c2) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(a_mul) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(d_mul) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1); + + ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(c1) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(c2) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 2); + ASSERT_TRUE(func_graph_index->node_degree_.at(e) == 1); + + ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end()); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end()); + + auto &add_set = func_graph_index->name_to_cnode_[kAddOpName]; + auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName]; + + ASSERT_TRUE(add_set.size() == 3); + ASSERT_TRUE(mul_set.size() == 3); + ASSERT_TRUE(add_set.find(add1) != add_set.end()); + ASSERT_TRUE(add_set.find(add) != add_set.end()); + ASSERT_TRUE(add_set.find(b_add_d) != add_set.end()); + ASSERT_TRUE(mul_set.find(a_mul) != mul_set.end()); + ASSERT_TRUE(mul_set.find(d_mul) != mul_set.end()); + ASSERT_TRUE(mul_set.find(c1_mul_c2) != mul_set.end()); + + auto new_node = pass.Run(fg_, add); + ASSERT_NE(new_node, nullptr); + (void)manager->Replace(add, new_node); + pass.AfterProcess(add, new_node, fg_, func_graph_index); + + ASSERT_TRUE(func_graph_index->node_degree_.at(b_add_d) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(c1_mul_c2) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(a_mul) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(d_mul) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("ad")) == 1); + + ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(c1) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(c2) == 0); + ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1); + ASSERT_TRUE(func_graph_index->node_degree_.at(e) == 1); + + ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end()); + ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end()); + + auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName]; + auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName]; + + ASSERT_TRUE(add_set_2.size() == 1); + ASSERT_TRUE(mul_set_2.size() == 1); + ASSERT_TRUE(add_set_2.find(add1) != add_set_2.end()); + ASSERT_TRUE(mul_set_2.find(pass.m_->Get("ad")) != mul_set_2.end()); + + // build pattern + check.src_pattern_.AddVar("a") + .AddVar("d") + .AddVar("e") + .AddCNode("ad", {std::make_shared(kMulOpName), "a", "d"}) + .AddCNode("add1", {std::make_shared(kAddOpName), "ad", "e"}); + + // pattern engine + ASSERT_TRUE(check.build_pattern_map(add1)); + + // check + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("d"), d)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("e"), e)); + + ASSERT_EQ(check.m_->Get("ad")->cast()->inputs().size(), 3); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast()->input(0), + NewValueNode(std::make_shared(kMulOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast()->input(1), a)); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast()->input(2), d)); + + ASSERT_EQ(check.m_->Get("add1")->cast()->inputs().size(), 3); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(0), + NewValueNode(std::make_shared(kAddOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(1), check.m_->Get("ad"))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast()->input(2), e)); +} + +namespace { +void Check0(const FuncGraphIndexPtr &fg, const std::map &node_map) { + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAAddB)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kCAddD)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kMul)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAdd)) == kOne); + + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kA)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kB)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kC)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kD)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kE)) == kOne); + + ASSERT_TRUE(fg->name_to_cnode_.size() == kTwo); + ASSERT_TRUE(fg->name_to_cnode_.find(kAddOpName) != fg->name_to_cnode_.end()); + ASSERT_TRUE(fg->name_to_cnode_.find(kMulOpName) != fg->name_to_cnode_.end()); + + auto &add_set = fg->name_to_cnode_[kAddOpName]; + auto &mul_set = fg->name_to_cnode_[kMulOpName]; + + ASSERT_TRUE(add_set.size() == kThree); + ASSERT_TRUE(mul_set.size() == kOne); + ASSERT_TRUE(add_set.find(node_map.at(kAdd)) != add_set.end()); + ASSERT_TRUE(add_set.find(node_map.at(kAAddB)) != add_set.end()); + ASSERT_TRUE(add_set.find(node_map.at(kCAddD)) != add_set.end()); + ASSERT_TRUE(mul_set.find(node_map.at(kMul)) != mul_set.end()); +} +void Check1(const TestFastMul2 &pass, const FuncGraphIndexPtr &fg, const std::map &node_map) { + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAAddB)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kCAddD)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kMul)) == kZero); + ASSERT_TRUE(fg->node_degree_.at(pass.m_->Get(kMul)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAdd)) == kOne); + + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kA)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kB)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kC)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kD)) == kOne); + ASSERT_TRUE(fg->node_degree_.at(node_map.at(kE)) == kOne); + + ASSERT_TRUE(fg->name_to_cnode_.size() == kTwo); + ASSERT_TRUE(fg->name_to_cnode_.find(kAddOpName) != fg->name_to_cnode_.end()); + ASSERT_TRUE(fg->name_to_cnode_.find(kMulOpName) != fg->name_to_cnode_.end()); + + auto &add_set_2 = fg->name_to_cnode_[kAddOpName]; + auto &mul_set_2 = fg->name_to_cnode_[kMulOpName]; + + ASSERT_TRUE(add_set_2.size() == kThree); + ASSERT_TRUE(mul_set_2.size() == kOne); + ASSERT_TRUE(add_set_2.find(node_map.at(kAAddB)) != add_set_2.end()); + ASSERT_TRUE(add_set_2.find(node_map.at(kCAddD)) != add_set_2.end()); + ASSERT_TRUE(mul_set_2.find(pass.m_->Get(kMul)) != mul_set_2.end()); +} + +void Check2(const CheckPattern &check, const std::map &node_map) { + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kA), node_map.at(kA))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kB), node_map.at(kB))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kC), node_map.at(kC))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kD), node_map.at(kD))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kE), node_map.at(kE))); + + ASSERT_EQ(check.m_->Get(kAAddB)->cast()->inputs().size(), kThree); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast()->input(kZero), + NewValueNode(std::make_shared(kAddOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast()->input(kOne), node_map.at(kA))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast()->input(kTwo), node_map.at(kB))); + + ASSERT_EQ(check.m_->Get(kCAddD)->cast()->inputs().size(), kThree); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast()->input(kZero), + NewValueNode(std::make_shared(kAddOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast()->input(kOne), node_map.at(kC))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast()->input(kTwo), node_map.at(kD))); + + ASSERT_EQ(check.m_->Get(kMul)->cast()->inputs().size(), kThree); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast()->input(kZero), + NewValueNode(std::make_shared(kMulOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast()->input(kOne), node_map.at(kCAddD))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast()->input(kTwo), node_map.at(kAAddB))); + + ASSERT_EQ(check.m_->Get(kAdd)->cast()->inputs().size(), kThree); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast()->input(kZero), + NewValueNode(std::make_shared(kAddOpName)))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast()->input(kOne), check.m_->Get(kMul))); + ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast()->input(kTwo), node_map.at(kE))); +} +} // namespace + +/// Feature: Fast PatternToPattern Pass +/// Description: Fast PatternToPattern Pass rewrite graph +/// Expectation: Get correct Graph +TEST_F(TestFastPatternToPatternPass, Mul2) { + // ((a + b) * (c + d)) + e -> ((c + d) * (a + b)) + e + // init + auto check = CheckPattern(); + auto pass = TestFastMul2(); + + // build func graph + auto a = std::make_shared(fg_); + auto b = std::make_shared(fg_); + auto c = std::make_shared(fg_); + auto d = std::make_shared(fg_); + auto e = std::make_shared(fg_); + + AnfNodePtr a_add_b = + std::make_shared(std::vector{NewValueNode(std::make_shared(kAddOpName)), a, b}, fg_); + AnfNodePtr c_add_d = + std::make_shared(std::vector{NewValueNode(std::make_shared(kAddOpName)), c, d}, fg_); + AnfNodePtr mul = std::make_shared( + std::vector{NewValueNode(std::make_shared(kMulOpName)), a_add_b, c_add_d}, fg_); + AnfNodePtr add = std::make_shared( + std::vector{NewValueNode(std::make_shared(kAddOpName)), mul, e}, fg_); + + std::map node_map; + node_map.emplace("a", a); + node_map.emplace("b", b); + node_map.emplace("c", c); + node_map.emplace("d", d); + node_map.emplace("e", e); + node_map.emplace("a_add_b", a_add_b); + node_map.emplace("c_add_d", c_add_d); + node_map.emplace("mul", mul); + node_map.emplace("add", add); + + fg_->set_output(add); + auto manager = MakeManager({fg_}); + if (manager) { + manager->AddFuncGraph(fg_); + fg_->set_manager(manager); + } + auto func_graph_index = manager->func_graph_index(fg_); + GenIndex(fg_, func_graph_index); + + Check0(func_graph_index, node_map); + auto new_node = pass.Run(fg_, mul); + ASSERT_NE(new_node, nullptr); + (void)manager->Replace(mul, new_node); + pass.AfterProcess(mul, new_node, fg_, func_graph_index); + Check1(pass, func_graph_index, node_map); + + // build pattern + check.src_pattern_.AddVar("a") + .AddVar("b") + .AddVar("c") + .AddVar("d") + .AddVar("e") + .AddCNode("a_add_b", {std::make_shared(kAddOpName), "a", "b"}) + .AddCNode("c_add_d", {std::make_shared(kAddOpName), "c", "d"}) + .AddCNode("mul", {std::make_shared(kMulOpName), "c_add_d", "a_add_b"}) + .AddCNode("add", {std::make_shared(kAddOpName), "mul", "e"}); + + // pattern engine + ASSERT_TRUE(check.build_pattern_map(add)); + Check2(check, node_map); +} +} // namespace opt +} // namespace mindspore diff --git a/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_test.cc b/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_test.cc index 05301030a97..060e38735cc 100644 --- a/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_test.cc +++ b/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_test.cc @@ -14,62 +14,11 @@ * limitations under the License. */ -#include -#include -#include "common/common_test.h" -#define private public -#define protected public -#include "backend/common/optimizer/pattern_to_pattern.h" -#undef private -#undef protected - -#include "mindspore/core/ops/core_ops.h" -#include "ir/anf.h" -#include "ir/value.h" -#include "include/common/utils/utils.h" -#include "backend/common/session/anf_runtime_algorithm.h" +#include "pattern_to_pattern_pass_utils.h" namespace mindspore { namespace opt { -class TestPatternToPatternPass : public UT::Common { - public: - TestPatternToPatternPass() : fg_(std::make_shared()){}; - - public: - FuncGraphPtr fg_; -}; - -class CheckPattern { - public: - CheckPattern() - : m_(std::make_shared()), - src_pattern_(SrcPattern(m_)), - pattern_engine_(PatternEngine(std::make_shared())), - primitive_vars_(std::make_shared()), - equiv_(std::make_shared()){}; - bool build_pattern_map(const AnfNodePtr &node) { - VarPtr root_g = std::make_shared("RootG"); - auto src_pattern_root = SexpToNode(src_pattern_.GetRoot(), root_g, primitive_vars_.get(), multigraph_); - auto primitive = GetCNodePrimitive(src_pattern_root); - if (IsPrimitiveCNode(node, primitive)) { - MS_EXCEPTION_IF_NULL(primitive_vars_); - MS_EXCEPTION_IF_NULL(equiv_); - equiv_->clear(); - EquivPtr equiv = pattern_engine_.Match(src_pattern_root, node, *primitive_vars_, equiv_); - if (equiv != nullptr && !equiv->empty()) { - return src_pattern_.build_pattern_map(node, equiv); - } - } - return false; - } - PatternMapPtr m_; - SrcPattern src_pattern_; - PatternEngine pattern_engine_; - PrimitiveVarMapPtr primitive_vars_; - EquivPtr equiv_; - bool multigraph_ = true; -}; - +namespace { class TestMul0 : public PatternToPatternPass { // a*b + a*c -> a*(b+c) public: @@ -227,6 +176,15 @@ class TestError1 : public PatternToPatternPass { } bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; } }; +} // namespace + +class TestPatternToPatternPass : public UT::Common { + public: + TestPatternToPatternPass() : fg_(std::make_shared()){}; + + public: + FuncGraphPtr fg_; +}; /// Feature: PatternToPattern Pass /// Description: PatternToPattern Pass rewrite graph @@ -422,8 +380,7 @@ TEST_F(TestPatternToPatternPass, Null) { AnfNodePtr add = std::make_shared( std::vector{NewValueNode(std::make_shared(kAddOpName)), ab, ac}, fg_); - auto new_node = pass.Run(fg_, add); - ASSERT_EQ(new_node, nullptr); + EXPECT_THROW(pass.Run(fg_, add), std::runtime_error); } /// Feature: PatternToPattern Pass diff --git a/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_utils.h b/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_utils.h new file mode 100644 index 00000000000..3e24214d85a --- /dev/null +++ b/tests/ut/cpp/pre_activate/common/pattern_to_pattern_pass_utils.h @@ -0,0 +1,70 @@ +/** +* Copyright 2023 Huawei Technologies Co., Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. + */ + +#ifndef MINDSPORE_TESTS_UT_CPP_PRE_ACTIVATE_COMMON_PATTERN_TO_PATTERN_PASS_UTILS_H_ +#define MINDSPORE_TESTS_UT_CPP_PRE_ACTIVATE_COMMON_PATTERN_TO_PATTERN_PASS_UTILS_H_ + +#include +#include +#include "common/common_test.h" +#include "mindspore/core/ops/core_ops.h" +#include "ir/anf.h" +#include "ir/value.h" +#include "include/common/utils/utils.h" +#include "backend/common/session/anf_runtime_algorithm.h" + +#define private public +#define protected public +#include "backend/common/optimizer/pattern_to_pattern.h" +#undef private +#undef protected + +namespace mindspore { +namespace opt { +class CheckPattern { + public: + CheckPattern() + : m_(std::make_shared()), + src_pattern_(SrcPattern(m_)), + pattern_engine_(PatternEngine(std::make_shared())), + primitive_vars_(std::make_shared()), + equiv_(std::make_shared()){}; + bool build_pattern_map(const AnfNodePtr &node) { + VarPtr root_g = std::make_shared("RootG"); + auto src_pattern_root = SexpToNode(src_pattern_.GetRoot(), root_g, primitive_vars_.get(), multigraph_); + auto primitive = GetCNodePrimitive(src_pattern_root); + if (IsPrimitiveCNode(node, primitive)) { + MS_EXCEPTION_IF_NULL(primitive_vars_); + MS_EXCEPTION_IF_NULL(equiv_); + equiv_->clear(); + EquivPtr equiv = pattern_engine_.Match(src_pattern_root, node, *primitive_vars_, equiv_); + if (equiv != nullptr && !equiv->empty()) { + return src_pattern_.build_pattern_map(node, equiv); + } + } + return false; + } + PatternMapPtr m_; + SrcPattern src_pattern_; + PatternEngine pattern_engine_; + PrimitiveVarMapPtr primitive_vars_; + EquivPtr equiv_; + bool multigraph_ = true; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_TESTS_UT_CPP_PRE_ACTIVATE_COMMON_PATTERN_TO_PATTERN_PASS_UTILS_H_