forked from mindspore-Ecosystem/mindspore
!48052 Optimize Ascend Backend Pass
Merge pull request !48052 from 王禹程/dyn_pass_ci
This commit is contained in:
commit
92dd278d21
|
@ -45,6 +45,10 @@ FuncGraphPtr GraphOptimizer::Optimize(const FuncGraphPtr &func_graph, bool run_o
|
|||
std::vector<FuncGraphPtr> func_graphs;
|
||||
func_graphs.push_back(func_graph);
|
||||
(void)TopoSort(func_graph->get_return());
|
||||
auto func_graph_index = manager->func_graph_index(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_index);
|
||||
func_graph_index->set_has_gen_index(false);
|
||||
|
||||
return func_graph;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -704,9 +704,17 @@ bool AnfEqual(const BaseRef &a, const BaseRef &b) {
|
|||
if (b_value_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Value ptr is nullptr.";
|
||||
}
|
||||
|
||||
if (a_value_ptr->isa<tensor::Tensor>() && b_value_ptr->isa<tensor::Tensor>()) {
|
||||
auto a_tensor_ptr = a_value_ptr->cast<tensor::TensorPtr>();
|
||||
auto b_tensor_ptr = b_value_ptr->cast<tensor::TensorPtr>();
|
||||
if (a_tensor_ptr == nullptr || b_tensor_ptr == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Cast value node ptr fail.";
|
||||
}
|
||||
return a_tensor_ptr->ValueEqual(*b_tensor_ptr);
|
||||
} else {
|
||||
return (*a_value_ptr) == (*b_value_ptr);
|
||||
}
|
||||
}
|
||||
MS_LOG(DEBUG) << "check AnfNodePtr equal";
|
||||
}
|
||||
if (utils::isa<FuncGraphPtr>(a) && utils::isa<FuncGraphPtr>(b)) {
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/common/optimizer/inplace_node_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
AnfNodePtr InplaceNodePass::Run(const FuncGraphPtr &, const AnfNodePtr &node) {
|
||||
std::vector<AnfNodePtr> pre_inputs;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto inputs = cnode->inputs();
|
||||
pre_inputs.insert(pre_inputs.end(), inputs.begin(), inputs.end());
|
||||
}
|
||||
bool ret = Process(node);
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto inputs = cnode->inputs();
|
||||
if (inputs.size() != pre_inputs.size()) {
|
||||
MS_LOG(EXCEPTION) << "InplaceNodePass ERROR, the pass modify node: " << node->DebugString()
|
||||
<< ", pass name: " << name();
|
||||
}
|
||||
for (size_t i = 0; i < inputs.size(); i++) {
|
||||
MS_EXCEPTION_IF_NULL(inputs[i]);
|
||||
MS_EXCEPTION_IF_NULL(pre_inputs[i]);
|
||||
if (!opt::AnfEqual(inputs[i], pre_inputs[i])) {
|
||||
MS_LOG(EXCEPTION) << "InplaceNodePass ERROR, the pass modify node: " << node->DebugString()
|
||||
<< ", pass name: " << name() << ", before node " << i << ":" << inputs[i]->DebugString()
|
||||
<< ", after node " << i << ":" << pre_inputs[i]->DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
if (ret) {
|
||||
return node;
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_INPLACE_NODE_PASS_H
|
||||
#define MINDSPORE_INPLACE_NODE_PASS_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "utils/hash_map.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "backend/common/optimizer/pass_manager.h"
|
||||
#include "backend/common/optimizer/pattern_engine.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
#include "backend/common/optimizer/graph_optimizer.h"
|
||||
#include "include/backend/visible.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class BACKEND_EXPORT InplaceNodePass : public NodePass {
|
||||
public:
|
||||
explicit InplaceNodePass(const std::string &name = "") : NodePass(name) {}
|
||||
~InplaceNodePass() override = default;
|
||||
virtual bool Process(const AnfNodePtr &) const = 0;
|
||||
AnfNodePtr Run(const FuncGraphPtr &, const AnfNodePtr &node) override;
|
||||
bool IsFastPass() override { return true; }
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_INPLACE_NODE_PASS_H
|
|
@ -17,6 +17,9 @@
|
|||
|
||||
#include <deque>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include <algorithm>
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/manager.h"
|
||||
|
@ -27,14 +30,18 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
const size_t kSwitchBranchIndex = 2;
|
||||
const size_t kCallArgsIndex = 1;
|
||||
const size_t kPartialArgsIndex = 1;
|
||||
} // namespace
|
||||
|
||||
void UpdateCallerAbstract(const AnfNodePtr &call_node, const FuncGraphPtr &call_node_fg,
|
||||
const FuncGraphPtr &sub_graph) {
|
||||
MS_EXCEPTION_IF_NULL(call_node);
|
||||
MS_EXCEPTION_IF_NULL(call_node_fg);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph->output());
|
||||
call_node->set_abstract(sub_graph->output()->abstract());
|
||||
auto manager = call_node_fg->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -58,9 +65,8 @@ void UpdateCallerAbstract(const AnfNodePtr &call_node, const FuncGraphPtr &call_
|
|||
}
|
||||
}
|
||||
|
||||
void AddOutputAndCallerToMap(
|
||||
const CNodePtr &cnode, const FuncGraphPtr &fg,
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<std::pair<AnfNodePtr, FuncGraphPtr>>> *out_caller_map) {
|
||||
void ModifyOutputAndCallerToMap(const CNodePtr &cnode, const FuncGraphPtr &fg,
|
||||
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *out_caller_map, bool is_add) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(out_caller_map);
|
||||
auto inputs = cnode->inputs();
|
||||
|
@ -73,28 +79,39 @@ void AddOutputAndCallerToMap(
|
|||
}
|
||||
auto switch_subgraph = GetValueNode<FuncGraphPtr>(partial_inputs.at(kPartialArgsIndex));
|
||||
MS_EXCEPTION_IF_NULL(switch_subgraph);
|
||||
(*out_caller_map)[switch_subgraph->output()].emplace_back(cnode, fg);
|
||||
if (is_add) {
|
||||
(*out_caller_map)[switch_subgraph->output()].insert(cnode);
|
||||
UpdateCallerAbstract(cnode, fg, switch_subgraph);
|
||||
} else {
|
||||
(*out_caller_map)[switch_subgraph->output()].erase(cnode);
|
||||
}
|
||||
} else if (common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
|
||||
auto call_subgraph = GetValueNode<FuncGraphPtr>(inputs.at(kCallArgsIndex));
|
||||
MS_EXCEPTION_IF_NULL(call_subgraph);
|
||||
(*out_caller_map)[call_subgraph->output()].emplace_back(cnode, fg);
|
||||
if (is_add) {
|
||||
(*out_caller_map)[call_subgraph->output()].insert(cnode);
|
||||
UpdateCallerAbstract(cnode, fg, call_subgraph);
|
||||
} else {
|
||||
(*out_caller_map)[call_subgraph->output()].erase(cnode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UpdateSubGraphCaller(
|
||||
const AnfNodePtr &origin_output, const FuncGraphPtr &fg,
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<std::pair<AnfNodePtr, FuncGraphPtr>>> *out_caller_map) {
|
||||
void UpdateSubGraphCaller(const AnfNodePtr &origin_output, const FuncGraphPtr &fg,
|
||||
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *out_caller_map,
|
||||
const mindspore::HashMap<AnfNodePtr, FuncGraphWeakPtr> &node_to_fg) {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
MS_EXCEPTION_IF_NULL(fg->output());
|
||||
auto find_iter = (*out_caller_map).find(origin_output);
|
||||
if (find_iter != (*out_caller_map).end()) {
|
||||
auto call_node_list = find_iter->second;
|
||||
(*out_caller_map).erase(find_iter);
|
||||
for (auto &call_node_pair : call_node_list) {
|
||||
auto call_node = call_node_pair.first;
|
||||
auto call_node_fg = call_node_pair.second;
|
||||
for (auto &call_node : call_node_list) {
|
||||
auto fg_iter = node_to_fg.find(call_node);
|
||||
if (fg_iter == node_to_fg.end()) {
|
||||
MS_LOG(EXCEPTION) << "Node to Funcgraph find failed: " << call_node->fullname_with_scope();
|
||||
}
|
||||
auto call_node_fg = fg_iter->second.lock();
|
||||
UpdateCallerAbstract(call_node, call_node_fg, fg);
|
||||
}
|
||||
(*out_caller_map)[fg->output()] = call_node_list;
|
||||
|
@ -111,22 +128,169 @@ void SkipSameOp(const AnfNodePtr &old_node, const AnfNodePtr &new_node, mindspor
|
|||
}
|
||||
}
|
||||
|
||||
bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
||||
std::string GetCNodeKey(const AnfNodePtr &node) {
|
||||
auto primitive = GetCNodePrimitive(node);
|
||||
if (primitive != nullptr) {
|
||||
return primitive->name();
|
||||
} else {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
void GenIndex(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_index);
|
||||
if (func_graph_index->has_gen_index()) {
|
||||
return;
|
||||
}
|
||||
|
||||
func_graph_index->set_has_gen_index(true);
|
||||
func_graph_index->node_to_fg_.clear();
|
||||
func_graph_index->node_degree_.clear();
|
||||
func_graph_index->name_to_cnode_.clear();
|
||||
func_graph_index->subgraph_out_caller_map_.clear();
|
||||
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->AddFuncGraph(func_graph);
|
||||
|
||||
// maybe call subgraph many times
|
||||
mindspore::HashMap<AnfNodePtr, std::vector<std::pair<AnfNodePtr, FuncGraphPtr>>> subgraph_out_caller_map = {};
|
||||
mindspore::HashSet<AnfNodePtr> seen_node;
|
||||
std::deque<std::pair<AnfNodePtr, FuncGraphPtr>> todo{{func_graph->output(), func_graph}};
|
||||
bool changes = false;
|
||||
|
||||
while (!todo.empty()) {
|
||||
AnfNodePtr node = todo.front().first;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto fg = todo.front().second;
|
||||
manager->AddFuncGraph(fg);
|
||||
todo.pop_front();
|
||||
|
||||
func_graph_index->node_to_fg_[node] = fg;
|
||||
auto degree_iter = func_graph_index->node_degree_.find(node);
|
||||
if (degree_iter == func_graph_index->node_degree_.end()) {
|
||||
func_graph_index->node_degree_[node] = 1;
|
||||
} else {
|
||||
degree_iter->second++;
|
||||
}
|
||||
if (node->isa<CNode>()) {
|
||||
func_graph_index->name_to_cnode_[GetCNodeKey(node)].insert(node);
|
||||
}
|
||||
|
||||
if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
|
||||
continue;
|
||||
}
|
||||
(void)seen_node.insert(node);
|
||||
TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
|
||||
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
auto const_func_graph = GetValueNode<FuncGraphPtr>(node);
|
||||
MS_EXCEPTION_IF_NULL(const_func_graph);
|
||||
if (!const_func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
(void)todo.emplace_back(const_func_graph->output(), const_func_graph);
|
||||
}
|
||||
} else if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
ModifyOutputAndCallerToMap(cnode, fg, &func_graph_index->subgraph_out_caller_map_);
|
||||
auto inputs = cnode->inputs();
|
||||
(void)std::for_each(inputs.begin(), inputs.end(),
|
||||
[&fg, &todo](AnfNodePtr &node) { (void)todo.emplace_back(node, fg); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool NodePass::ProcessFastPassNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
|
||||
const FuncGraphIndexPtr &func_graph_index, const FuncGraphManagerPtr &manager) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_index);
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto iter = func_graph_index->node_to_fg_.find(node);
|
||||
if (iter == func_graph_index->node_to_fg_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Node to Funcgraph map can't find node: " << node->fullname_with_scope();
|
||||
}
|
||||
auto fg = iter->second.lock();
|
||||
TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info()));
|
||||
auto degree_iter = func_graph_index->node_degree_.find(node);
|
||||
if (degree_iter == func_graph_index->node_degree_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Node degree map can't find node: " << node->fullname_with_scope();
|
||||
}
|
||||
auto degree = degree_iter->second;
|
||||
if (degree == 0 && node != func_graph->output()) {
|
||||
return false;
|
||||
}
|
||||
// we may update return value in some pass.
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
auto origin_output = fg->output();
|
||||
MS_EXCEPTION_IF_NULL(origin_output);
|
||||
auto origin_abstract = origin_output->abstract();
|
||||
AnfNodePtr new_node = Run(fg, node);
|
||||
bool change = (new_node != nullptr);
|
||||
MS_EXCEPTION_IF_NULL(fg->output());
|
||||
if (origin_abstract != fg->output()->abstract()) {
|
||||
UpdateSubGraphCaller(origin_output, fg, &func_graph_index->subgraph_out_caller_map_, func_graph_index->node_to_fg_);
|
||||
}
|
||||
if (new_node != nullptr && new_node != node) {
|
||||
(void)manager->Replace(node, new_node);
|
||||
// if replaced node is end_goto, refresh relative params in kernel graph
|
||||
auto kernel_graph = fg->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
if (kernel_graph != nullptr && node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto end_label = kernel_graph->get_end_goto();
|
||||
if (cnode == end_label && common::AnfAlgo::GetCNodeName(cnode) == kLabelSwitchOpName) {
|
||||
kernel_graph->set_end_goto(new_node->cast<CNodePtr>());
|
||||
}
|
||||
}
|
||||
AfterProcess(node, new_node, fg, func_graph_index);
|
||||
}
|
||||
return change;
|
||||
}
|
||||
|
||||
bool NodePass::ProcessFastPass(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_index);
|
||||
if (!func_graph_index->has_gen_index()) {
|
||||
MS_LOG(EXCEPTION) << "ProcessFastPass Error, func graph has not gen index, pass name: " << name();
|
||||
}
|
||||
auto src_pattern_root_name = GetPatternRootPrimitiveName();
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
bool changes = false;
|
||||
|
||||
std::vector<AnfNodePtr> cand_node;
|
||||
if (!src_pattern_root_name.empty()) {
|
||||
auto cnode_iter = func_graph_index->name_to_cnode_.find(src_pattern_root_name);
|
||||
if (cnode_iter == func_graph_index->name_to_cnode_.end()) {
|
||||
return false;
|
||||
}
|
||||
std::copy(cnode_iter->second.begin(), cnode_iter->second.end(), std::back_inserter(cand_node));
|
||||
} else {
|
||||
for (const auto &kv : func_graph_index->name_to_cnode_) {
|
||||
std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(cand_node));
|
||||
}
|
||||
}
|
||||
for (const auto &node : cand_node) {
|
||||
auto change = ProcessFastPassNode(node, func_graph, func_graph_index, manager);
|
||||
changes = changes || change;
|
||||
}
|
||||
return changes;
|
||||
}
|
||||
|
||||
bool NodePass::ProcessPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
bool changes = false;
|
||||
|
||||
// maybe call subgraph many times
|
||||
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> subgraph_out_caller_map = {};
|
||||
mindspore::HashMap<AnfNodePtr, FuncGraphWeakPtr> node_to_fg = {};
|
||||
mindspore::HashSet<AnfNodePtr> seen_node;
|
||||
std::deque<std::pair<AnfNodePtr, FuncGraphPtr>> todo{{func_graph->output(), func_graph}};
|
||||
while (!todo.empty()) {
|
||||
AnfNodePtr node = todo.front().first;
|
||||
auto fg = todo.front().second;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
manager->AddFuncGraph(fg);
|
||||
todo.pop_front();
|
||||
node_to_fg[node] = fg;
|
||||
if (seen_node.count(node) > 0 || !manager->all_nodes().contains(node)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -140,7 +304,7 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
|||
AnfNodePtr new_node = Run(fg, node);
|
||||
bool change = (new_node != nullptr);
|
||||
if (origin_abstract != fg->output()->abstract()) {
|
||||
UpdateSubGraphCaller(origin_output, fg, &subgraph_out_caller_map);
|
||||
UpdateSubGraphCaller(origin_output, fg, &subgraph_out_caller_map, node_to_fg);
|
||||
}
|
||||
if (new_node != nullptr && new_node != node) {
|
||||
SkipSameOp(node, new_node, &seen_node);
|
||||
|
@ -171,15 +335,44 @@ bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
|||
}
|
||||
auto cnode = new_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
AddOutputAndCallerToMap(cnode, fg, &subgraph_out_caller_map);
|
||||
ModifyOutputAndCallerToMap(cnode, fg, &subgraph_out_caller_map);
|
||||
auto inputs = cnode->inputs();
|
||||
(void)std::for_each(inputs.begin(), inputs.end(), [&fg, &todo](AnfNodePtr &node) {
|
||||
(void)todo.emplace_back(std::pair<AnfNodePtr, FuncGraphPtr>(node, fg));
|
||||
});
|
||||
(void)std::for_each(inputs.begin(), inputs.end(),
|
||||
[&fg, &todo](AnfNodePtr &node) { (void)todo.emplace_back(node, fg); });
|
||||
}
|
||||
changes = changes || change;
|
||||
}
|
||||
return changes;
|
||||
}
|
||||
|
||||
bool NodePass::Run(const FuncGraphPtr &func_graph) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
FuncGraphManagerPtr manager = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
manager->AddFuncGraph(func_graph);
|
||||
auto func_graph_index = manager->func_graph_index(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_index);
|
||||
|
||||
if (IsFastPass()) {
|
||||
MS_LOG(INFO) << "Run fast pass: " << name();
|
||||
GenIndex(func_graph, func_graph_index);
|
||||
return ProcessFastPass(func_graph, func_graph_index);
|
||||
}
|
||||
if (func_graph_index->has_gen_index()) {
|
||||
auto ret = MustExistPrimitiveName();
|
||||
for (const auto &primtive_name : ret) {
|
||||
auto cnode_iter = func_graph_index->name_to_cnode_.find(primtive_name);
|
||||
if (cnode_iter == func_graph_index->name_to_cnode_.end()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!ret.empty()) {
|
||||
MS_LOG(INFO) << "Skip pass fail, run pass: " << name();
|
||||
}
|
||||
}
|
||||
func_graph_index->set_has_gen_index(false);
|
||||
|
||||
return ProcessPass(func_graph, manager);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
#include "backend/common/optimizer/pass.h"
|
||||
#include "include/backend/visible.h"
|
||||
|
@ -28,10 +30,25 @@ class BACKEND_EXPORT NodePass : public Pass {
|
|||
public:
|
||||
explicit NodePass(const std::string &name) : Pass(name) {}
|
||||
~NodePass() override = default;
|
||||
virtual bool Run(const FuncGraphPtr &func_graph);
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
virtual bool IsFastPass() { return false; }
|
||||
virtual void AfterProcess(const AnfNodePtr &, const AnfNodePtr &, const FuncGraphPtr &, const FuncGraphIndexPtr &) {}
|
||||
virtual std::string GetPatternRootPrimitiveName() { return ""; }
|
||||
virtual AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) = 0;
|
||||
virtual std::vector<std::string> MustExistPrimitiveName() const { return {}; }
|
||||
|
||||
private:
|
||||
bool ProcessFastPassNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph,
|
||||
const FuncGraphIndexPtr &func_graph_index, const FuncGraphManagerPtr &manager);
|
||||
bool ProcessFastPass(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index);
|
||||
bool ProcessPass(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &manager);
|
||||
};
|
||||
using NodePassPtr = std::shared_ptr<NodePass>;
|
||||
void GenIndex(const FuncGraphPtr &func_graph, const FuncGraphIndexPtr &func_graph_index);
|
||||
void ModifyOutputAndCallerToMap(const CNodePtr &cnode, const FuncGraphPtr &fg,
|
||||
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> *out_caller_map,
|
||||
bool is_add = true);
|
||||
std::string GetCNodeKey(const AnfNodePtr &node);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_COMMON_NODE_PASS_H_
|
||||
|
|
|
@ -16,7 +16,10 @@
|
|||
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
#include "ir/manager.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -60,15 +63,19 @@ const std::vector<AnfNodePtr> &PatternMap::GetSeq(const std::string &name) const
|
|||
}
|
||||
|
||||
bool PatternMap::Emplace(const std::string &name, const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
name_set_.insert(name);
|
||||
if (seq_map_.find(name) != seq_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Var Key: " << name << " should not be in SeqVarMap.";
|
||||
}
|
||||
|
||||
opt_scope_.insert(node);
|
||||
|
||||
auto iter = node_map_.find(name);
|
||||
if (iter == node_map_.end()) {
|
||||
node_map_.emplace(name, node);
|
||||
} else if (!opt::AnfEqual(node, iter->second)) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
MS_LOG(INFO) << "The value of key: " << name
|
||||
<< " is not equal to origin value, value: " + node->fullname_with_scope()
|
||||
<< " origin value: " << iter->second->fullname_with_scope();
|
||||
|
@ -83,6 +90,10 @@ bool PatternMap::Emplace(const std::string &name, const std::vector<AnfNodePtr>
|
|||
MS_LOG(EXCEPTION) << "SeqVar Key: " << name << " should not be in VarMap.";
|
||||
}
|
||||
|
||||
for (const auto &node : v) {
|
||||
opt_scope_.insert(node);
|
||||
}
|
||||
|
||||
auto iter = seq_map_.find(name);
|
||||
if (iter == seq_map_.end()) {
|
||||
seq_map_.emplace(name, v);
|
||||
|
@ -96,6 +107,8 @@ bool PatternMap::Emplace(const std::string &name, const std::vector<AnfNodePtr>
|
|||
}
|
||||
|
||||
for (size_t i = 0; i < v.size(); i++) {
|
||||
MS_EXCEPTION_IF_NULL(v[i]);
|
||||
MS_EXCEPTION_IF_NULL(origin_v[i]);
|
||||
if (!opt::AnfEqual(v[i], origin_v[i])) {
|
||||
MS_LOG(INFO) << "The value of key: " << name
|
||||
<< " is not equal to origin value, value: " + v[i]->fullname_with_scope()
|
||||
|
@ -181,6 +194,7 @@ BaseRef SrcPattern::GetRoot() const {
|
|||
|
||||
const Seq &GetSeq(const std::string &pattern_name, const std::string &node_name, const VarPtr &var,
|
||||
const EquivPtr &equiv) {
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto equiv_iter = equiv->find(var);
|
||||
if (equiv_iter == equiv->end()) {
|
||||
MS_LOG(EXCEPTION) << "The SeqVar Key: " << pattern_name << " is not in EquivMap, node name: " << node_name;
|
||||
|
@ -204,6 +218,7 @@ bool SrcPattern::CheckEmptySeqVar(const std::string &name, const EquivPtr &equiv
|
|||
MS_EXCEPTION_IF_CHECK_FAIL(seq.size() == IntToSize(0), "Match Failed, need zero seq, but get seq length: " +
|
||||
std::to_string(seq.size()) + ", node name: " + name);
|
||||
std::vector<AnfNodePtr> v;
|
||||
MS_EXCEPTION_IF_NULL(m_);
|
||||
if (!m_->Emplace(pattern_node.name_, v)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -214,6 +229,9 @@ bool SrcPattern::CheckEmptySeqVar(const std::string &name, const EquivPtr &equiv
|
|||
}
|
||||
|
||||
bool SrcPattern::match(const std::string &name, const AnfNodePtr &node, const EquivPtr &equiv) {
|
||||
MS_EXCEPTION_IF_NULL(m_);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
auto input_iter = inputs_map_.find(name);
|
||||
if (input_iter == inputs_map_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Key: " << name << " is not a CNode.";
|
||||
|
@ -234,6 +252,8 @@ bool SrcPattern::match(const std::string &name, const AnfNodePtr &node, const Eq
|
|||
auto &match_node = cnode_inputs[now_match];
|
||||
if (pattern_node.type_ == "prim") {
|
||||
// prim
|
||||
MS_EXCEPTION_IF_NULL(pattern_node.p_);
|
||||
MS_EXCEPTION_IF_NULL(match_node);
|
||||
if (!opt::AnfEqual(pattern_node.p_, match_node)) {
|
||||
MS_LOG(EXCEPTION) << "The value of Primitive is not equal to matched value, pattern value: " +
|
||||
pattern_node.p_->ToString()
|
||||
|
@ -296,6 +316,7 @@ bool SrcPattern::build_pattern_map(const AnfNodePtr &node, const EquivPtr &equiv
|
|||
|
||||
DstPattern &DstPattern::AddCNode(const string &name, const std::initializer_list<PatternNode> &inputs,
|
||||
const BuildCNodeFunc &buildfunc) {
|
||||
MS_EXCEPTION_IF_NULL(m_);
|
||||
if (fail_) {
|
||||
return *this;
|
||||
}
|
||||
|
@ -343,10 +364,12 @@ DstPattern &DstPattern::AddCNode(const string &name, const std::initializer_list
|
|||
<< ", CNode: " << name;
|
||||
}
|
||||
for (size_t i = 0; i < anf_inputs.size(); i++) {
|
||||
MS_EXCEPTION_IF_NULL(anf_inputs[i]);
|
||||
MS_EXCEPTION_IF_NULL(cnode->input(i));
|
||||
if (!opt::AnfEqual(anf_inputs[i], cnode->input(i))) {
|
||||
MS_LOG(EXCEPTION) << "The actual input does not correspond to the input of the pattern, the input index: " << i
|
||||
<< ", actual input: " << anf_inputs[i]->fullname_with_scope()
|
||||
<< ", pattern input: " << new_node->cast<CNodePtr>()->input(i)->fullname_with_scope()
|
||||
<< ", actual input: " << anf_inputs[i]->DebugString()
|
||||
<< ", pattern input: " << new_node->cast<CNodePtr>()->input(i)->DebugString()
|
||||
<< ", CNode: " << name;
|
||||
}
|
||||
}
|
||||
|
@ -360,6 +383,7 @@ DstPattern &DstPattern::AddCNode(const string &name, const std::initializer_list
|
|||
}
|
||||
|
||||
DstPattern &DstPattern::AddValueNode(const string &name, const BuildValueFunc &buildfunc) {
|
||||
MS_EXCEPTION_IF_NULL(m_);
|
||||
if (fail_) {
|
||||
return *this;
|
||||
}
|
||||
|
@ -379,6 +403,7 @@ DstPattern &DstPattern::AddValueNode(const string &name, const BuildValueFunc &b
|
|||
}
|
||||
|
||||
void DstPattern::clear() {
|
||||
MS_EXCEPTION_IF_NULL(m_);
|
||||
fail_ = false;
|
||||
root_ = nullptr;
|
||||
m_->Erase(dst_set_);
|
||||
|
@ -406,12 +431,28 @@ UnpackNode &UnpackNode::operator=(const std::string &name) {
|
|||
return *this;
|
||||
}
|
||||
|
||||
AnfNodePtr PatternToPatternPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
AnfNodePtr PatternToPatternPass::GetSrcPatternRoot() {
|
||||
if (src_pattern_root_ == nullptr) {
|
||||
DefineSrcPattern(&src_pattern_);
|
||||
VarPtr fg = std::make_shared<Var>("RootG");
|
||||
src_pattern_root_ = SexpToNode(src_pattern_.GetRoot(), fg, primitive_vars_.get(), multigraph_);
|
||||
}
|
||||
return src_pattern_root_;
|
||||
}
|
||||
|
||||
std::string PatternToPatternPass::GetPatternRootPrimitiveName() {
|
||||
auto src_pattern_root = GetSrcPatternRoot();
|
||||
auto prim = GetCNodePrimitive(src_pattern_root);
|
||||
if (prim != nullptr) {
|
||||
return prim->name();
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
AnfNodePtr PatternToPatternPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
if (src_pattern_root_ == nullptr) {
|
||||
(void)GetSrcPatternRoot();
|
||||
}
|
||||
|
||||
auto primitive = GetCNodePrimitive(src_pattern_root_);
|
||||
if (IsPrimitiveCNode(node, primitive)) {
|
||||
|
@ -435,11 +476,217 @@ AnfNodePtr PatternToPatternPass::Run(const FuncGraphPtr &func_graph, const AnfNo
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
namespace {
|
||||
const auto kStageZero = 0;
|
||||
const auto kStageOne = 1;
|
||||
const auto kStageTwo = 2;
|
||||
|
||||
void DeleteCNode(const AnfNodePtr &node, const FuncGraphPtr &sub_graph, const FuncGraphIndexPtr &func_graph_index) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_index);
|
||||
if (node->isa<CNode>()) {
|
||||
auto name_to_cnode_iter = func_graph_index->name_to_cnode_.find(GetCNodeKey(node));
|
||||
if (name_to_cnode_iter == func_graph_index->name_to_cnode_.end()) {
|
||||
MS_LOG(EXCEPTION) << "ProcessFastPass Error, name_to_cnode_ can't find cnode_name: "
|
||||
<< common::AnfAlgo::GetCNodeName(node);
|
||||
}
|
||||
auto &cnode_set = name_to_cnode_iter->second;
|
||||
auto cnode_set_iter = cnode_set.find(node);
|
||||
if (cnode_set_iter == cnode_set.end()) {
|
||||
MS_LOG(EXCEPTION) << "ProcessFastPass Error, name_to_cnode_ can't find node: " << node->fullname_with_scope();
|
||||
}
|
||||
cnode_set.erase(cnode_set_iter);
|
||||
ModifyOutputAndCallerToMap(node->cast<CNodePtr>(), sub_graph, &func_graph_index->subgraph_out_caller_map_, false);
|
||||
}
|
||||
}
|
||||
|
||||
void AppendChild(const AnfNodePtr &node, const FuncGraphPtr &fg,
|
||||
std::queue<std::pair<AnfNodePtr, FuncGraphPtr>> *anf_q) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
MS_EXCEPTION_IF_NULL(anf_q);
|
||||
if (IsValueNode<FuncGraph>(node)) {
|
||||
auto const_func_graph = GetValueNode<FuncGraphPtr>(node);
|
||||
MS_EXCEPTION_IF_NULL(const_func_graph);
|
||||
if (!const_func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
|
||||
anf_q->emplace(const_func_graph->output(), const_func_graph);
|
||||
}
|
||||
} else if (node->isa<CNode>()) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (const auto &input_node : cnode->inputs()) {
|
||||
anf_q->emplace(input_node, fg);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool DelSrcPattern(const std::pair<AnfNodePtr, FuncGraphPtr> &top, const AnfNodePtr &root,
|
||||
const mindspore::HashSet<AnfNodePtr> &opt_scope,
|
||||
std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete,
|
||||
const FuncGraphIndexPtr &func_graph_index) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
MS_EXCEPTION_IF_NULL(need_delete);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_index);
|
||||
auto node = top.first;
|
||||
auto fg = top.second;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
if (node != root) {
|
||||
auto degree_iter = func_graph_index->node_degree_.find(node);
|
||||
if (degree_iter == func_graph_index->node_degree_.end()) {
|
||||
MS_LOG(EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope() << " not in degree map";
|
||||
}
|
||||
if (degree_iter->second <= 0) {
|
||||
MS_LOG(EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope()
|
||||
<< " degree error, degree: " << degree_iter->second;
|
||||
}
|
||||
degree_iter->second--;
|
||||
if (degree_iter->second > 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (opt_scope.find(node) == opt_scope.end()) {
|
||||
(*need_delete).insert({node, fg});
|
||||
return false;
|
||||
}
|
||||
|
||||
DeleteCNode(node, fg, func_graph_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool AddDstPattern(const std::pair<AnfNodePtr, FuncGraphPtr> &top, const AnfNodePtr &root,
|
||||
const mindspore::HashSet<AnfNodePtr> &opt_scope,
|
||||
std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete,
|
||||
const FuncGraphIndexPtr &func_graph_index) {
|
||||
MS_EXCEPTION_IF_NULL(root);
|
||||
MS_EXCEPTION_IF_NULL(need_delete);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_index);
|
||||
auto node = top.first;
|
||||
auto fg = top.second;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
if (node->isa<CNode>()) {
|
||||
ModifyOutputAndCallerToMap(node->cast<CNodePtr>(), fg, &func_graph_index->subgraph_out_caller_map_);
|
||||
func_graph_index->name_to_cnode_[GetCNodeKey(node)].insert(node);
|
||||
func_graph_index->node_to_fg_[node] = fg;
|
||||
}
|
||||
|
||||
if (node != root) {
|
||||
auto degree_iter = func_graph_index->node_degree_.find(node);
|
||||
if (degree_iter == func_graph_index->node_degree_.end()) {
|
||||
func_graph_index->node_degree_[node] = 0;
|
||||
degree_iter = func_graph_index->node_degree_.find(node);
|
||||
}
|
||||
degree_iter->second++;
|
||||
if (degree_iter->second != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (opt_scope.find(node) == opt_scope.end()) {
|
||||
(*need_delete).erase({node, fg});
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool DelCascadeNode(const std::pair<AnfNodePtr, FuncGraphPtr> &top,
|
||||
std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete,
|
||||
const FuncGraphIndexPtr &func_graph_index) {
|
||||
MS_EXCEPTION_IF_NULL(need_delete);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_index);
|
||||
auto node = top.first;
|
||||
auto fg = top.second;
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
if ((*need_delete).find({node, fg}) == (*need_delete).end()) {
|
||||
auto degree_iter = func_graph_index->node_degree_.find(node);
|
||||
if (degree_iter == func_graph_index->node_degree_.end()) {
|
||||
MS_LOG(EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope() << " not in degree map";
|
||||
}
|
||||
if (degree_iter->second <= 0) {
|
||||
MS_LOG(EXCEPTION) << "ProcessFastPass Error, node: " << node->fullname_with_scope()
|
||||
<< " degree error, degree: " << degree_iter->second;
|
||||
}
|
||||
degree_iter->second--;
|
||||
if (degree_iter->second > 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
DeleteCNode(node, fg, func_graph_index);
|
||||
return true;
|
||||
}
|
||||
|
||||
void BFS(const AnfNodePtr &root, const FuncGraphPtr &sub_graph, const mindspore::HashSet<AnfNodePtr> &opt_scope,
|
||||
std::set<std::pair<AnfNodePtr, FuncGraphPtr>> *need_delete, const FuncGraphIndexPtr &func_graph_index,
|
||||
size_t stage) {
|
||||
std::queue<std::pair<AnfNodePtr, FuncGraphPtr>> anf_q;
|
||||
|
||||
if (stage == kStageZero || stage == kStageOne) {
|
||||
anf_q.emplace(root, sub_graph);
|
||||
} else if (stage == kStageTwo) {
|
||||
for (const auto &p : (*need_delete)) {
|
||||
anf_q.push(p);
|
||||
}
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Illegal BFS stage, expected stage is 0/1/2, but get stage: " << stage;
|
||||
}
|
||||
|
||||
while (!anf_q.empty()) {
|
||||
auto top = anf_q.front();
|
||||
anf_q.pop();
|
||||
|
||||
bool ret = false;
|
||||
if (stage == kStageZero) {
|
||||
ret = DelSrcPattern(top, root, opt_scope, need_delete, func_graph_index);
|
||||
} else if (stage == kStageOne) {
|
||||
ret = AddDstPattern(top, root, opt_scope, need_delete, func_graph_index);
|
||||
} else if (stage == kStageTwo) {
|
||||
ret = DelCascadeNode(top, need_delete, func_graph_index);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Illegal BFS stage, expected stage is 0/1/2, but get stage: " << stage;
|
||||
}
|
||||
if (!ret) {
|
||||
continue;
|
||||
}
|
||||
|
||||
AppendChild(top.first, top.second, &anf_q);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void PatternToPatternPass::AfterProcess(const AnfNodePtr &old_node, const AnfNodePtr &new_node,
|
||||
const FuncGraphPtr &sub_graph, const FuncGraphIndexPtr &func_graph_index) {
|
||||
MS_EXCEPTION_IF_NULL(m_);
|
||||
MS_EXCEPTION_IF_NULL(old_node);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
MS_EXCEPTION_IF_NULL(sub_graph);
|
||||
MS_EXCEPTION_IF_NULL(func_graph_index);
|
||||
std::set<std::pair<AnfNodePtr, FuncGraphPtr>> need_delete;
|
||||
auto &opt_scope = m_->GetOptScope();
|
||||
|
||||
auto old_node_iter = func_graph_index->node_degree_.find(old_node);
|
||||
if (old_node_iter == func_graph_index->node_degree_.end()) {
|
||||
MS_LOG(EXCEPTION) << "ProcessFastPass Error, old_node: " << old_node->fullname_with_scope() << " not in degree map";
|
||||
}
|
||||
auto origin_degree = old_node_iter->second;
|
||||
|
||||
func_graph_index->node_degree_[new_node] = origin_degree;
|
||||
func_graph_index->node_degree_[old_node] = 0;
|
||||
|
||||
BFS(old_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageZero);
|
||||
BFS(new_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageOne);
|
||||
BFS(new_node, sub_graph, opt_scope, &need_delete, func_graph_index, kStageTwo);
|
||||
}
|
||||
|
||||
std::vector<UnpackNode> PatternToPatternPass::Unpacking(const std::string &s) {
|
||||
MS_EXCEPTION_IF_NULL(m_);
|
||||
auto v = m_->GetSeq(s);
|
||||
std::vector<UnpackNode> ret;
|
||||
std::transform(v.begin(), v.end(), std::back_inserter(ret), [](const AnfNodePtr &node) { return UnpackNode(node); });
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool PatternToPatternPass::IsFastPass() { return is_fast_pass_; }
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -43,11 +43,13 @@ class BACKEND_EXPORT PatternMap {
|
|||
void Clear();
|
||||
bool Check(const std::string &name, const AnfNodePtr &node) const;
|
||||
void Erase(const mindspore::HashSet<std::string> &del_set);
|
||||
const mindspore::HashSet<AnfNodePtr> &GetOptScope() const { return opt_scope_; }
|
||||
|
||||
private:
|
||||
mindspore::HashSet<std::string> name_set_;
|
||||
mindspore::HashMap<std::string, AnfNodePtr> node_map_;
|
||||
mindspore::HashMap<std::string, std::vector<AnfNodePtr>> seq_map_;
|
||||
mindspore::HashSet<AnfNodePtr> opt_scope_;
|
||||
};
|
||||
|
||||
using PatternMapPtr = std::shared_ptr<PatternMap>;
|
||||
|
@ -163,16 +165,22 @@ class BACKEND_EXPORT DstPattern {
|
|||
|
||||
class BACKEND_EXPORT PatternToPatternPass : public PatternPass {
|
||||
public:
|
||||
explicit PatternToPatternPass(const std::string &name = "", bool multigraph = true)
|
||||
explicit PatternToPatternPass(const std::string &name = "", bool is_fast_pass = false, bool multigraph = true)
|
||||
: PatternPass(name, multigraph),
|
||||
m_(std::make_shared<PatternMap>()),
|
||||
src_pattern_(SrcPattern(m_)),
|
||||
dst_pattern_(DstPattern(m_)) {}
|
||||
dst_pattern_(DstPattern(m_)),
|
||||
is_fast_pass_(is_fast_pass) {}
|
||||
~PatternToPatternPass() override = default;
|
||||
virtual void DefineSrcPattern(SrcPattern *src_pattern) = 0;
|
||||
virtual void DefineDstPattern(DstPattern *dst_pattern) = 0;
|
||||
virtual bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const = 0;
|
||||
virtual bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const { return true; }
|
||||
bool IsFastPass() override;
|
||||
AnfNodePtr GetSrcPatternRoot();
|
||||
std::string GetPatternRootPrimitiveName() override;
|
||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||
void AfterProcess(const AnfNodePtr &old_node, const AnfNodePtr &new_node, const FuncGraphPtr &sub_graph,
|
||||
const FuncGraphIndexPtr &func_graph_index) override;
|
||||
std::vector<UnpackNode> Unpacking(const std::string &s);
|
||||
|
||||
private:
|
||||
|
@ -180,6 +188,7 @@ class BACKEND_EXPORT PatternToPatternPass : public PatternPass {
|
|||
SrcPattern src_pattern_;
|
||||
DstPattern dst_pattern_;
|
||||
AnfNodePtr src_pattern_root_ = nullptr;
|
||||
bool is_fast_pass_;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,7 +15,11 @@
|
|||
*/
|
||||
|
||||
#include "backend/common/pass/add_dropout_attrs.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
||||
|
@ -65,6 +69,12 @@ const AnfNodePtr AddDropoutAttrs::Process(const FuncGraphPtr &func_graph, const
|
|||
return cnode;
|
||||
}
|
||||
|
||||
std::vector<std::string> AddDropoutAttrs::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimDropout->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef AddDropoutAttrs::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({prim::kPrimDropout, Xs});
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_DROPOUT_ATTRS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_ADD_DROPOUT_ATTRS_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -27,6 +29,9 @@ class AddDropoutAttrs : public PatternProcessPass {
|
|||
~AddDropoutAttrs() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,17 +21,16 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const AnfNodePtr AddDynamicShapeAttr::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
bool AddDynamicShapeAttr::Process(const AnfNodePtr &node) const {
|
||||
if (common::AnfAlgo::IsDynamicShape(node)) {
|
||||
auto func_graph = node->func_graph();
|
||||
MS_LOG(DEBUG) << "Set Dynamic Shape Attr to Node:" << node->fullname_with_scope();
|
||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
kernel_graph->SetGraphDynamicAttr(true);
|
||||
return true;
|
||||
}
|
||||
return node;
|
||||
return false;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,17 +16,13 @@
|
|||
|
||||
#ifndef MINDSPORE_ADD_DYNAMIC_SHAPE_ATTR_H
|
||||
#define MINDSPORE_ADD_DYNAMIC_SHAPE_ATTR_H
|
||||
#include <string>
|
||||
#include "ir/anf.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/inplace_node_pass.h"
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class AddDynamicShapeAttr : public PatternProcessPass {
|
||||
class AddDynamicShapeAttr : public InplaceNodePass {
|
||||
public:
|
||||
explicit AddDynamicShapeAttr(bool multigraph = true) : PatternProcessPass("add_dynamic_shape_attr", multigraph) {}
|
||||
~AddDynamicShapeAttr() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
AddDynamicShapeAttr() : InplaceNodePass("add_dynamic_shape_attr") {}
|
||||
bool Process(const AnfNodePtr &node) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -25,18 +25,12 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kCNodePrimitiveIdx = 0;
|
||||
}
|
||||
constexpr auto kXs = "Xs";
|
||||
constexpr auto kMConv2dTrans = "m_conv2d_trans";
|
||||
constexpr auto kRConv2dBp = "r_conv2d_bp";
|
||||
|
||||
const BaseRef ConvTransposeToConvBackpropInputPass::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
auto conv_transpose = std::make_shared<Primitive>(kConv2DTransposeOpName);
|
||||
return VectorRef({conv_transpose, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr ConvTransposeToConvBackpropInputPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
AnfNodePtr BuildConv2DBackpropInput(const PatternMap &m, const AnfNodePtr &default_node) {
|
||||
auto node = m.Get(kMConv2dTrans);
|
||||
auto conv_transpose = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(conv_transpose);
|
||||
|
||||
|
@ -51,5 +45,19 @@ const AnfNodePtr ConvTransposeToConvBackpropInputPass::Process(const FuncGraphPt
|
|||
|
||||
return node;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool ConvTransposeToConvBackpropInputPass::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &,
|
||||
const AnfNodePtr &) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
void ConvTransposeToConvBackpropInputPass::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddSeqVar(kXs).AddCNode(kMConv2dTrans, {prim::kPrimConv2DTranspose, kXs});
|
||||
}
|
||||
|
||||
void ConvTransposeToConvBackpropInputPass::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern).AddCNode(kRConv2dBp, {prim::kPrimConv2DBackpropInput, kXs}, BuildConv2DBackpropInput);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,17 +17,17 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONV_TRANSPOSE_TO_CONV_BP_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ConvTransposeToConvBackpropInputPass : public PatternProcessPass {
|
||||
class ConvTransposeToConvBackpropInputPass : public PatternToPatternPass {
|
||||
public:
|
||||
explicit ConvTransposeToConvBackpropInputPass(bool multigraph = true)
|
||||
: PatternProcessPass("conv_transpose_to_conv_backprop_input", multigraph) {}
|
||||
ConvTransposeToConvBackpropInputPass() : PatternToPatternPass("conv_transpose_to_conv_backprop_input", true) {}
|
||||
~ConvTransposeToConvBackpropInputPass() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,10 +23,9 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
bool ConvertAttrToUnifyMindIR::Process(const AnfNodePtr &node) const {
|
||||
if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) {
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
|
@ -51,7 +50,7 @@ const AnfNodePtr ConvertAttrToUnifyMindIR::Process(const FuncGraphPtr &, const A
|
|||
}
|
||||
}
|
||||
|
||||
return node;
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,17 +16,15 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_ATTR_TO_UNIFY_MINDIR_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CONVERT_ATTR_TO_UNIFY_MINDIR_H_
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/inplace_node_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ConvertAttrToUnifyMindIR : public PatternProcessPass {
|
||||
class ConvertAttrToUnifyMindIR : public InplaceNodePass {
|
||||
public:
|
||||
explicit ConvertAttrToUnifyMindIR(bool multigraph = true)
|
||||
: PatternProcessPass("convert_attr_to_unify_mindir", multigraph) {}
|
||||
~ConvertAttrToUnifyMindIR() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
ConvertAttrToUnifyMindIR() : InplaceNodePass("convert_attr_to_unify_mindir") {}
|
||||
bool Process(const AnfNodePtr &node) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,22 +17,23 @@
|
|||
#include <memory>
|
||||
#include "backend/common/pass/convert_dynamic_broadcast_to.h"
|
||||
#include "ir/anf.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const AnfNodePtr ConvertDynamicBroadcastTo::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
namespace {
|
||||
const auto kV = "V";
|
||||
const auto kMBroadcastTo = "m_broadcast_to";
|
||||
const auto kRBroadcastTo = "r_broadcast_to";
|
||||
AnfNodePtr BuildDynamicBroadcastTo(const PatternMap &m, const AnfNodePtr &) {
|
||||
auto node = m.Get(kMBroadcastTo);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto node_name = common::AnfAlgo::GetCNodeName(node);
|
||||
if (node_name == prim::kPrimDynamicBroadcastTo->name() && !common::AnfAlgo::IsDynamicShape(node)) {
|
||||
auto broadcast_to_op_name = prim::kPrimBroadcastTo->name();
|
||||
auto ori_cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(ori_cnode);
|
||||
auto input_x = common::AnfAlgo::GetInputNode(ori_cnode, 0);
|
||||
auto func_graph = node->func_graph();
|
||||
CNodePtr broadcast_to_node =
|
||||
opt::NewCNode({NewValueNode(std::make_shared<Primitive>(broadcast_to_op_name)), input_x}, func_graph, {node});
|
||||
MS_EXCEPTION_IF_NULL(broadcast_to_node);
|
||||
|
@ -42,7 +43,23 @@ const AnfNodePtr ConvertDynamicBroadcastTo::Process(const FuncGraphPtr &func_gra
|
|||
common::AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(shape_ptr->shape()), broadcast_to_node);
|
||||
return broadcast_to_node;
|
||||
}
|
||||
return node;
|
||||
} // namespace
|
||||
|
||||
bool ConvertDynamicBroadcastTo::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &,
|
||||
const AnfNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!common::AnfAlgo::IsDynamicShape(node)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void ConvertDynamicBroadcastTo::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddVar(kV).AddCNode(kMBroadcastTo, {prim::kPrimDynamicBroadcastTo, kV});
|
||||
}
|
||||
|
||||
void ConvertDynamicBroadcastTo::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern).AddCNode(kRBroadcastTo, {prim::kPrimDynamicBroadcastTo, kV}, BuildDynamicBroadcastTo);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,18 +16,19 @@
|
|||
|
||||
#ifndef MINDSPORE_CONVERT_DYNAMIC_BROADCAST_TO_ATTR_H
|
||||
#define MINDSPORE_CONVERT_DYNAMIC_BROADCAST_TO_ATTR_H
|
||||
#include <string>
|
||||
#include "ir/anf.h"
|
||||
#include "include/common/utils/convert_utils.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ConvertDynamicBroadcastTo : public PatternProcessPass {
|
||||
class ConvertDynamicBroadcastTo : public PatternToPatternPass {
|
||||
public:
|
||||
explicit ConvertDynamicBroadcastTo(bool multigraph = true)
|
||||
: PatternProcessPass("convert_dynamic_broadcast_to", multigraph) {}
|
||||
ConvertDynamicBroadcastTo() : PatternToPatternPass("convert_dynamic_broadcast_to", true) {}
|
||||
~ConvertDynamicBroadcastTo() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -27,6 +27,10 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kXs = "Xs";
|
||||
constexpr auto kMCustom = "m_custom";
|
||||
constexpr auto kRCustom = "r_custom";
|
||||
|
||||
void ParseAttrDefaultValue(const std::string &op_name, const std::string &attr_name, const std::string &attr_value,
|
||||
const std::string &attr_type, const PrimitivePtr &prim) {
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
|
@ -117,30 +121,14 @@ void AddMissingAttrs(const CNodePtr &cnode, kernel::OpImplyType imply_type,
|
|||
cnode->set_input(kAnfPrimitiveIndex, NewValueNode(primitive));
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const AnfNodePtr CustomOpRegInfoToAttr::Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
if (node == nullptr || !AnfUtils::IsRealCNodeKernel(node)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
AnfNodePtr BuildCustom(const PatternMap &m, const AnfNodePtr &default_node) {
|
||||
auto cnode = m.Get(kMCustom)->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (!IsPrimitiveCNode(cnode, prim::kPrimCustom)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto func_type = common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFuncType);
|
||||
// AKG/AICPU need to process attr, TBE will process later in the json creating phase.
|
||||
if (!IsOneOfCustomAkgType(func_type) || func_type == kCustomTypeAICPU) {
|
||||
return nullptr;
|
||||
}
|
||||
// Early return if current node does not have attr
|
||||
auto attr_names = primitive->GetAttr(kAttrAttrNames);
|
||||
if (attr_names == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
// Early return if all attr in reg info exist in the node's attr
|
||||
std::unordered_set<std::string> missing_attrs;
|
||||
auto attr_names_vec = GetValue<std::vector<std::string>>(attr_names);
|
||||
|
@ -156,7 +144,34 @@ const AnfNodePtr CustomOpRegInfoToAttr::Process(const FuncGraphPtr &, const AnfN
|
|||
func_type == kCustomTypeAICPU ? kernel::OpImplyType::kImplyAICPU : kernel::OpImplyType::kImplyAKG;
|
||||
AddMissingAttrs(cnode, imply_type, missing_attrs);
|
||||
|
||||
return node;
|
||||
return cnode;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool CustomOpRegInfoToAttr::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &node) const {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto func_type = common::AnfAlgo::GetNodeAttr<std::string>(cnode, kAttrFuncType);
|
||||
// AKG/AICPU need to process attr, TBE will process later in the json creating phase.
|
||||
if (!IsOneOfCustomAkgType(func_type) || func_type == kCustomTypeAICPU) {
|
||||
return false;
|
||||
}
|
||||
// Early return if current node does not have attr
|
||||
auto attr_names = primitive->GetAttr(kAttrAttrNames);
|
||||
if (attr_names == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void CustomOpRegInfoToAttr::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddSeqVar(kXs).AddCNode(kMCustom, {prim::kPrimCustom, kXs});
|
||||
}
|
||||
|
||||
void CustomOpRegInfoToAttr::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern).AddCNode(kRCustom, {prim::kPrimCustom, kXs}, BuildCustom);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -15,17 +15,18 @@
|
|||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CUSTOM_OP_REG_INFO_TO_ATTR_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_CUSTOM_OP_REG_INFO_TO_ATTR_H_
|
||||
#include "ir/anf.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class CustomOpRegInfoToAttr : public PatternProcessPass {
|
||||
class CustomOpRegInfoToAttr : public PatternToPatternPass {
|
||||
public:
|
||||
explicit CustomOpRegInfoToAttr(bool multigraph = true)
|
||||
: PatternProcessPass("custom_op_reg_info_to_attr", multigraph) {}
|
||||
CustomOpRegInfoToAttr() : PatternToPatternPass("custom_op_reg_info_to_attr", true) {}
|
||||
~CustomOpRegInfoToAttr() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <utility>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
|
||||
|
@ -138,6 +139,12 @@ void ExpandFlattenConcatTupleInput(const FuncGraphPtr &graph, const CNodePtr &cn
|
|||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<std::string> FlattenConcatFission::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimFlattenConcat->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef FlattenConcatFission::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
return VectorRef({prim::kPrimFlattenConcat, Xs});
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FLATTEN_CONCAT_FISSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_FLATTEN_CONCAT_FISSION_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -26,6 +28,9 @@ class FlattenConcatFission : public PatternProcessPass {
|
|||
~FlattenConcatFission() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -100,6 +100,12 @@ CNodePtr InplaceAssignAfterTupleGetItem(const FuncGraphPtr &func_graph, const CN
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
std::vector<std::string> InplaceAssignForCustomOp::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimCustom->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const AnfNodePtr InplaceAssignForCustomOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
|
|
@ -15,6 +15,10 @@
|
|||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_INPLACE_ASSIGN_FOR_CUSTOM_OP_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_PASS_INPLACE_ASSIGN_FOR_CUSTOM_OP_H_
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
|
||||
|
@ -26,6 +30,7 @@ class InplaceAssignForCustomOp : public PatternProcessPass {
|
|||
: PatternProcessPass("inplace_assign_for_custom_op", multigraph) {}
|
||||
~InplaceAssignForCustomOp() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
|
||||
private:
|
||||
mutable mindspore::HashSet<CNodePtr> visited_{};
|
||||
|
|
|
@ -19,10 +19,10 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto m_3d = "m_3d";
|
||||
constexpr auto V = "V";
|
||||
constexpr auto Xs = "Xs";
|
||||
constexpr auto r_3d = "r_3d";
|
||||
constexpr auto kM3d = "m_3d";
|
||||
constexpr auto kV = "V";
|
||||
constexpr auto kXs = "Xs";
|
||||
constexpr auto kR3d = "r_3d";
|
||||
} // namespace
|
||||
|
||||
bool AddIoFormatAttrFor3DGraph::CheckMatchedDAG(const PatternMap &m, const FuncGraphPtr &graph,
|
||||
|
@ -36,7 +36,7 @@ bool AddIoFormatAttrFor3DGraph::CheckMatchedDAG(const PatternMap &m, const FuncG
|
|||
}
|
||||
|
||||
AnfNodePtr AddAttr(const PatternMap &m, const AnfNodePtr & /* default_cnode */) {
|
||||
auto node = m.Get(m_3d);
|
||||
auto node = m.Get(kM3d);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
|
||||
auto formats = AnfAlgo::GetAllOutputFormats(node);
|
||||
if (std::any_of(formats.begin(), formats.end(), [](const std::string &format) { return IsOneOf3DFormat(format); })) {
|
||||
|
@ -45,10 +45,10 @@ AnfNodePtr AddAttr(const PatternMap &m, const AnfNodePtr & /* default_cnode */)
|
|||
return node;
|
||||
}
|
||||
void AddIoFormatAttrFor3DGraph::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(void)(*src_pattern).AddVar(V, UnVisited).AddSeqVar(Xs).AddCNode(m_3d, {V, Xs});
|
||||
(void)(*src_pattern).AddVar(kV, UnVisited).AddSeqVar(kXs).AddCNode(kM3d, {kV, kXs});
|
||||
}
|
||||
void AddIoFormatAttrFor3DGraph::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(void)(*dst_pattern).AddCNode(r_3d, {V, Xs}, AddAttr);
|
||||
(void)(*dst_pattern).AddCNode(kR3d, {kV, kXs}, AddAttr);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto Xs = "Xs";
|
||||
constexpr auto kXs = "Xs";
|
||||
constexpr auto call_inline = "call_inline";
|
||||
constexpr auto new_call_inline = "new_call_inline";
|
||||
} // namespace
|
||||
|
@ -41,11 +41,11 @@ AnfNodePtr BuildCallInline(const PatternMap &m, const AnfNodePtr &) {
|
|||
}
|
||||
|
||||
void ReselectCallInlineFormat::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddSeqVar(Xs).AddCNode(call_inline, {prim::kPrimCallInline, Xs});
|
||||
(*src_pattern).AddSeqVar(kXs).AddCNode(call_inline, {prim::kPrimCallInline, kXs});
|
||||
}
|
||||
|
||||
void ReselectCallInlineFormat::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern).AddCNode(new_call_inline, {prim::kPrimCallInline, Xs}, BuildCallInline);
|
||||
(*dst_pattern).AddCNode(new_call_inline, {prim::kPrimCallInline, kXs}, BuildCallInline);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -29,8 +29,8 @@ constexpr size_t kReduceInputNum = 2;
|
|||
constexpr size_t kAxisInputIndex = 2;
|
||||
constexpr auto r_reduce = "r_reduce";
|
||||
constexpr auto m_reduce = "m_reduce";
|
||||
constexpr auto Xs = "Xs";
|
||||
constexpr auto V = "V";
|
||||
constexpr auto kXs = "Xs";
|
||||
constexpr auto kV = "V";
|
||||
constexpr auto v_axis = "axis";
|
||||
} // namespace
|
||||
|
||||
|
@ -128,13 +128,13 @@ AnfNodePtr BuildReduce(const PatternMap &m, const AnfNodePtr &) {
|
|||
}
|
||||
|
||||
void ReduceAxisUpdate::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(void)(*src_pattern).AddVar(V, IsReduce).AddSeqVar(Xs).AddCNode(m_reduce, {V, Xs});
|
||||
(void)(*src_pattern).AddVar(kV, IsReduce).AddSeqVar(kXs).AddCNode(m_reduce, {kV, kXs});
|
||||
}
|
||||
|
||||
void ReduceAxisUpdate::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
auto reduce_input = Unpacking(Xs);
|
||||
auto reduce_input = Unpacking(kXs);
|
||||
reduce_input[kAxisInputIndex - 1] = v_axis;
|
||||
(void)(*dst_pattern).AddValueNode(v_axis, BuildAxis).AddCNode(r_reduce, {V, reduce_input}, BuildReduce);
|
||||
(void)(*dst_pattern).AddValueNode(v_axis, BuildAxis).AddCNode(r_reduce, {kV, reduce_input}, BuildReduce);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -21,10 +21,10 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto m_reduce_min = "m_reduce_min";
|
||||
constexpr auto r_reduce_min1 = "r_reduce_min1";
|
||||
constexpr auto r_reduce_min2 = "r_reduce_min2";
|
||||
constexpr auto X = "X";
|
||||
constexpr auto kMReduceMin = "m_reduce_min";
|
||||
constexpr auto kRReduceMin1 = "r_reduce_min1";
|
||||
constexpr auto kRReduceMin2 = "r_reduce_min2";
|
||||
constexpr auto kX1 = "X1";
|
||||
|
||||
bool NeedOptimize(const TypeId &dtype, const ShapeVector &shape, const std::vector<int64_t> &axis) {
|
||||
if (dtype != kNumberTypeFloat32) {
|
||||
|
@ -129,7 +129,7 @@ bool ReduceMinFission::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &g
|
|||
}
|
||||
|
||||
AnfNodePtr BuildReduceMin1(const PatternMap &m, const AnfNodePtr &default_node) {
|
||||
auto cnode = m.Get(m_reduce_min)->cast<CNodePtr>();
|
||||
auto cnode = m.Get(kMReduceMin)->cast<CNodePtr>();
|
||||
CNodePtr reduce_min1 = InitReduceMin(default_node->cast<CNodePtr>(), cnode);
|
||||
auto shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
|
||||
auto dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0);
|
||||
|
@ -143,7 +143,7 @@ AnfNodePtr BuildReduceMin1(const PatternMap &m, const AnfNodePtr &default_node)
|
|||
}
|
||||
|
||||
AnfNodePtr BuildReduceMin2(const PatternMap &m, const AnfNodePtr &default_node) {
|
||||
auto cnode = m.Get(m_reduce_min)->cast<CNodePtr>();
|
||||
auto cnode = m.Get(kMReduceMin)->cast<CNodePtr>();
|
||||
CNodePtr reduce_min2 = InitReduceMin(default_node->cast<CNodePtr>(), cnode);
|
||||
reduce_min2->set_abstract(cnode->abstract());
|
||||
std::vector<int64_t> axis_last = {-1};
|
||||
|
@ -152,13 +152,13 @@ AnfNodePtr BuildReduceMin2(const PatternMap &m, const AnfNodePtr &default_node)
|
|||
}
|
||||
|
||||
void ReduceMinFission::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(void)(*src_pattern).AddVar(X).AddCNode(m_reduce_min, {prim::kPrimReduceMinD, X});
|
||||
(void)(*src_pattern).AddVar(kX1).AddCNode(kMReduceMin, {prim::kPrimReduceMinD, kX1});
|
||||
}
|
||||
|
||||
void ReduceMinFission::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(void)(*dst_pattern)
|
||||
.AddCNode(r_reduce_min1, {prim::kPrimReduceMinD, X}, BuildReduceMin1)
|
||||
.AddCNode(r_reduce_min2, {prim::kPrimReduceMinD, r_reduce_min1}, BuildReduceMin2);
|
||||
.AddCNode(kRReduceMin1, {prim::kPrimReduceMinD, kX1}, BuildReduceMin1)
|
||||
.AddCNode(kRReduceMin2, {prim::kPrimReduceMinD, kRReduceMin1}, BuildReduceMin2);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -24,12 +24,7 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
|
||||
bool AICpuLibSelectPass::Process(const AnfNodePtr &node) const {
|
||||
static const std::set<std::string> kAICpuOpNames = {kDropoutGenMaskOpName,
|
||||
kEnvironCreateOpName,
|
||||
kEnvironSetOpName,
|
||||
|
@ -329,7 +324,7 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
|
|||
static const std::string kCpuKernelSoName = "mindspore_cpu_kernels";
|
||||
|
||||
if (!node->isa<CNode>()) {
|
||||
return node;
|
||||
return false;
|
||||
}
|
||||
auto kernel_name = common::AnfAlgo::GetCNodeName(node);
|
||||
if (kAICpuOpNames.find(kernel_name) != kAICpuOpNames.end()) {
|
||||
|
@ -339,7 +334,7 @@ const AnfNodePtr AICpuLibSelectPass::Process(const FuncGraphPtr &graph, const An
|
|||
common::AnfAlgo::SetNodeAttr(kAttrCustAicpu, MakeValue(kCpuKernelSoName), node);
|
||||
}
|
||||
|
||||
return node;
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,15 +16,14 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AICPU_LIB_SELECT_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AICPU_LIB_SELECT_H_
|
||||
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/inplace_node_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class AICpuLibSelectPass : public PatternProcessPass {
|
||||
class AICpuLibSelectPass : public InplaceNodePass {
|
||||
public:
|
||||
explicit AICpuLibSelectPass(bool multigraph = true) : PatternProcessPass("env_op_attr_update", multigraph) {}
|
||||
~AICpuLibSelectPass() override = default;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
AICpuLibSelectPass() : InplaceNodePass("env_op_attr_update") {}
|
||||
bool Process(const AnfNodePtr &node) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -181,6 +181,12 @@ CNodePtr AllToAllUnifyMindIR::CreateConcatNode(const FuncGraphPtr &graph, const
|
|||
return concat;
|
||||
}
|
||||
|
||||
std::vector<std::string> NeighborExchangeUnifyMindIR::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimNeighborExchange->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef NeighborExchangeUnifyMindIR::DefinePattern() const {
|
||||
return VectorRef({prim::kPrimNeighborExchange, std::make_shared<SeqVar>()});
|
||||
}
|
||||
|
@ -193,6 +199,12 @@ const AnfNodePtr NeighborExchangeUnifyMindIR::Process(const FuncGraphPtr &graph,
|
|||
return node;
|
||||
}
|
||||
|
||||
std::vector<std::string> AllToAllUnifyMindIR::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimAllToAll->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef AllToAllUnifyMindIR::DefinePattern() const {
|
||||
return VectorRef({prim::kPrimAllToAll, std::make_shared<SeqVar>()});
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_ALL_TO_ALL_UNIFY_MINDIR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -28,6 +30,9 @@ class NeighborExchangeUnifyMindIR : public PatternProcessPass {
|
|||
~NeighborExchangeUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
|
||||
class AllToAllUnifyMindIR : public PatternProcessPass {
|
||||
|
@ -41,6 +46,7 @@ class AllToAllUnifyMindIR : public PatternProcessPass {
|
|||
CNodePtr CreateSplitNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all) const;
|
||||
CNodePtr CreateAllToAllvNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &split) const;
|
||||
CNodePtr CreateConcatNode(const FuncGraphPtr &graph, const CNodePtr &all_to_all, const CNodePtr &all_to_all_v) const;
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,6 +38,15 @@ constexpr size_t kAvgPoolGradInputNum = 3;
|
|||
constexpr size_t kShapeDimNum = 4;
|
||||
constexpr float kKernelMatrixInitNum = 1.0;
|
||||
constexpr size_t kFloat32Len = 4; // size of float32
|
||||
constexpr auto kX1 = "X1";
|
||||
constexpr auto kX2 = "X2";
|
||||
constexpr auto kG = "G";
|
||||
constexpr auto kXShapeVNode = "XShapeVNode";
|
||||
constexpr auto kMeanMatrixVNode = "MeanMatrixVNode";
|
||||
constexpr auto kKernelMatrixVNode = "KernelMatrixVNode";
|
||||
constexpr auto kMAvgPoolGrad = "m_avg_pool_grad";
|
||||
constexpr auto kRAvgPoolGrad = "r_avg_pool_grad";
|
||||
|
||||
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
return common::AnfAlgo::GetPrevNodeOutputInferShape(node, 0UL);
|
||||
|
@ -181,24 +190,27 @@ ValueNodePtr CreateKernelMatrixValueNode(const FuncGraphPtr &func_graph, const A
|
|||
kernel_graph->AddValueNodeToGraph(kernel_matrix_vnode);
|
||||
return kernel_matrix_vnode;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef AvgPoolGradUnifyMindIR::DefinePattern() const {
|
||||
VarPtr X1 = std::make_shared<Var>();
|
||||
VarPtr X2 = std::make_shared<Var>();
|
||||
VarPtr G = std::make_shared<Var>();
|
||||
VectorRef pattern({prim::kPrimAvgPoolGrad, X1, X2, G});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
const AnfNodePtr AvgPoolGradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
class BuildXShapeVNode {
|
||||
public:
|
||||
BuildXShapeVNode() = default;
|
||||
AnfNodePtr operator()(const PatternMap &m) const {
|
||||
auto node = m.Get(kMAvgPoolGrad);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto avgpool_grad = CheckAnfNodeIfCNodeAndInputSize(node, kAvgPoolGradInputNum);
|
||||
auto x_shape = GetInputXShape(avgpool_grad);
|
||||
auto graph = avgpool_grad->func_graph();
|
||||
auto x_shape_vnode = CreateShapeValueNode(graph, x_shape);
|
||||
return x_shape_vnode;
|
||||
}
|
||||
};
|
||||
class BuildMeanMatrixVNode {
|
||||
public:
|
||||
BuildMeanMatrixVNode() = default;
|
||||
AnfNodePtr operator()(const PatternMap &m) const {
|
||||
auto node = m.Get(kMAvgPoolGrad);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto avgpool_grad = CheckAnfNodeIfCNodeAndInputSize(node, kAvgPoolGradInputNum);
|
||||
|
||||
auto x_shape = GetInputXShape(avgpool_grad);
|
||||
auto x_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(avgpool_grad, 0UL);
|
||||
auto k_size = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(avgpool_grad, kAttrKernelSize);
|
||||
auto stride = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(avgpool_grad, kAttrStrides);
|
||||
auto prim = GetCNodePrimitive(avgpool_grad);
|
||||
|
@ -206,15 +218,35 @@ const AnfNodePtr AvgPoolGradUnifyMindIR::Process(const FuncGraphPtr &graph, cons
|
|||
int64_t pad_mode_value = 0;
|
||||
CheckAndConvertUtils::GetPadModEnumValue(prim->GetAttr(kAttrPadMode), &pad_mode_value, true);
|
||||
auto pad_mode = PadMode(pad_mode_value);
|
||||
auto x_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(avgpool_grad, 0UL);
|
||||
|
||||
auto x_shape_vnode = CreateShapeValueNode(graph, x_shape);
|
||||
auto graph = avgpool_grad->func_graph();
|
||||
auto mean_matrix_vnode = CreateMeanMatrixValueNode(graph, node, x_shape, k_size, stride, pad_mode, x_dtype);
|
||||
return mean_matrix_vnode;
|
||||
}
|
||||
};
|
||||
class BuildKernelMatrixVNode {
|
||||
public:
|
||||
BuildKernelMatrixVNode() = default;
|
||||
AnfNodePtr operator()(const PatternMap &m) const {
|
||||
auto node = m.Get(kMAvgPoolGrad);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto avgpool_grad = CheckAnfNodeIfCNodeAndInputSize(node, kAvgPoolGradInputNum);
|
||||
auto k_size = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(avgpool_grad, kAttrKernelSize);
|
||||
auto x_dtype = common::AnfAlgo::GetPrevNodeOutputInferDataType(avgpool_grad, 0UL);
|
||||
auto x_shape = GetInputXShape(avgpool_grad);
|
||||
auto graph = avgpool_grad->func_graph();
|
||||
auto kernel_matrix_vnode = CreateKernelMatrixValueNode(graph, node, x_shape, k_size, x_dtype);
|
||||
return kernel_matrix_vnode;
|
||||
}
|
||||
};
|
||||
AnfNodePtr BuildAvgPoolGrad(const PatternMap &m, const AnfNodePtr &new_node) {
|
||||
auto node = m.Get(kMAvgPoolGrad);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto avgpool_grad = CheckAnfNodeIfCNodeAndInputSize(node, kAvgPoolGradInputNum);
|
||||
|
||||
std::vector<AnfNodePtr> avgpool_grad_vm_inputs = {NewValueNode(std::make_shared<Primitive>(kAvgPoolGradOpName)),
|
||||
x_shape_vnode, avgpool_grad->input(3UL), mean_matrix_vnode,
|
||||
kernel_matrix_vnode};
|
||||
auto avgpool_grad_vm = NewCNode(avgpool_grad_vm_inputs, graph);
|
||||
MS_EXCEPTION_IF_NULL(new_node);
|
||||
auto avgpool_grad_vm = new_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(avgpool_grad_vm);
|
||||
avgpool_grad_vm->set_scope(avgpool_grad->scope());
|
||||
avgpool_grad_vm->set_abstract(avgpool_grad->abstract());
|
||||
|
@ -228,5 +260,20 @@ const AnfNodePtr AvgPoolGradUnifyMindIR::Process(const FuncGraphPtr &graph, cons
|
|||
common::AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(output_names), avgpool_grad_vm);
|
||||
return avgpool_grad_vm;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void AvgPoolGradUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddVar(kX1).AddVar(kX2).AddVar(kG).AddCNode(kMAvgPoolGrad, {prim::kPrimAvgPoolGrad, kX1, kX2, kG});
|
||||
}
|
||||
|
||||
void AvgPoolGradUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern)
|
||||
.AddValueNode(kXShapeVNode, BuildXShapeVNode())
|
||||
.AddValueNode(kMeanMatrixVNode, BuildMeanMatrixVNode())
|
||||
.AddValueNode(kKernelMatrixVNode, BuildKernelMatrixVNode())
|
||||
.AddCNode(kRAvgPoolGrad,
|
||||
{std::make_shared<Primitive>(kAvgPoolGradOpName), kXShapeVNode, kG, kMeanMatrixVNode, kKernelMatrixVNode},
|
||||
BuildAvgPoolGrad);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,18 +16,18 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AVG_POOL_GRAD_UNIFY_MINDIR_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_AVG_POOL_GRAD_UNIFY_MINDIR_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class AvgPoolGradUnifyMindIR : public PatternProcessPass {
|
||||
class AvgPoolGradUnifyMindIR : public PatternToPatternPass {
|
||||
public:
|
||||
explicit AvgPoolGradUnifyMindIR(bool multigraph = true)
|
||||
: PatternProcessPass("avg_pool_grad_unify_mindir", multigraph) {}
|
||||
AvgPoolGradUnifyMindIR() : PatternToPatternPass("avg_pool_grad_unify_mindir", true) {}
|
||||
~AvgPoolGradUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -23,22 +23,25 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kAttrUnifyIRPassed = "unifyir_passed";
|
||||
constexpr auto kX1 = "X1";
|
||||
constexpr auto kX2 = "X2";
|
||||
constexpr auto kX3 = "X3";
|
||||
constexpr auto kX4 = "X4";
|
||||
constexpr auto kX5 = "X5";
|
||||
constexpr auto kXs = "Xs";
|
||||
constexpr auto kMBatchnormGrad = "m_batchnorm_grad";
|
||||
constexpr auto kRBatchnormGrad = "r_batchnorm_grad";
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr BatchNormGradUnifyMindIR::CreateNewBatchNormGrad(const FuncGraphPtr &graph,
|
||||
const CNodePtr &bn_grad_node) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
AnfNodePtr BuildBatchNormGrad(const PatternMap &m, const AnfNodePtr &new_node) {
|
||||
auto node = m.Get(kMBatchnormGrad);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto bn_grad_node = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(bn_grad_node);
|
||||
size_t kBNGradInputNum = 6;
|
||||
const auto &bn_grad_node_inputs = bn_grad_node->inputs();
|
||||
CheckCNodeInputSize(bn_grad_node, kBNGradInputNum);
|
||||
std::vector<AnfNodePtr> bn_grad_inputs = {NewValueNode(std::make_shared<Primitive>(kBatchNormGradOpName)),
|
||||
bn_grad_node_inputs[kDim1],
|
||||
bn_grad_node_inputs[kDim2],
|
||||
bn_grad_node_inputs[kDim3],
|
||||
bn_grad_node_inputs[kDim4],
|
||||
bn_grad_node_inputs[kDim5]};
|
||||
auto new_bn_grad = NewCNode(bn_grad_inputs, graph);
|
||||
auto new_bn_grad = new_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(new_bn_grad);
|
||||
MS_EXCEPTION_IF_NULL(new_bn_grad);
|
||||
new_bn_grad->set_scope(bn_grad_node->scope());
|
||||
auto types = {common::AnfAlgo::GetOutputInferDataType(bn_grad_node, 0UL),
|
||||
|
@ -57,24 +60,33 @@ AnfNodePtr BatchNormGradUnifyMindIR::CreateNewBatchNormGrad(const FuncGraphPtr &
|
|||
return new_bn_grad;
|
||||
}
|
||||
|
||||
const BaseRef BatchNormGradUnifyMindIR::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
auto prim = std::make_shared<Primitive>(kBatchNormGradOpName);
|
||||
return VectorRef({prim, Xs});
|
||||
}
|
||||
|
||||
const AnfNodePtr BatchNormGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
bool BatchNormGradUnifyMindIR::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &func_graph,
|
||||
const AnfNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (common::AnfAlgo::HasNodeAttr(kAttrUnifyIRPassed, cnode) ||
|
||||
(func_graph->has_flag(kAttrMutableKernel) && !GetBoolAttr(cnode, kAttrIsTraining))) {
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
return CreateNewBatchNormGrad(func_graph, cnode);
|
||||
return true;
|
||||
}
|
||||
|
||||
void BatchNormGradUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern)
|
||||
.AddVar(kX1)
|
||||
.AddVar(kX2)
|
||||
.AddVar(kX3)
|
||||
.AddVar(kX4)
|
||||
.AddVar(kX5)
|
||||
.AddSeqVar(kXs)
|
||||
.AddCNode(kMBatchnormGrad, {std::make_shared<Primitive>(kBatchNormGradOpName), kX1, kX2, kX3, kX4, kX5, kXs});
|
||||
}
|
||||
|
||||
void BatchNormGradUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern)
|
||||
.AddCNode(kRBatchnormGrad, {std::make_shared<Primitive>(kBatchNormGradOpName), kX1, kX2, kX3, kX4, kX5},
|
||||
BuildBatchNormGrad);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,19 +17,18 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_BN_GRAD_UNIFY_MINDIR_H_
|
||||
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class BatchNormGradUnifyMindIR : public PatternProcessPass {
|
||||
class BatchNormGradUnifyMindIR : public PatternToPatternPass {
|
||||
public:
|
||||
explicit BatchNormGradUnifyMindIR(bool multigraph = true) : PatternProcessPass("bn_grad_unify_mindir", multigraph) {}
|
||||
BatchNormGradUnifyMindIR() : PatternToPatternPass("bn_grad_unify_mindir", true) {}
|
||||
~BatchNormGradUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
AnfNodePtr CreateNewBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &bn_grad_node) const;
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -46,6 +46,11 @@ constexpr int64_t kV3ShapeLimitSize = 1 << 30;
|
|||
constexpr size_t kDropoutGradInputTensorNum = 2;
|
||||
constexpr size_t kFloat16Len = 2; // size of float16
|
||||
constexpr size_t kInt64Len = 8; // size of int64
|
||||
constexpr auto kX1 = "X1";
|
||||
constexpr auto kX2 = "X2";
|
||||
constexpr auto kKeepProbValue = "KeepProbValue";
|
||||
constexpr auto kMDropoutGrad = "m_dropout_grad";
|
||||
constexpr auto kRDropoutDoMask = "r_dropout_do_mask";
|
||||
|
||||
TypeId GetInputXDataType(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -303,6 +308,92 @@ void UpdateReturnNode(const FuncGraphPtr &graph, const AnfNodePtr &origin_node,
|
|||
g_output->set_abstract(abstract);
|
||||
graph->set_output(g_output);
|
||||
}
|
||||
|
||||
class BuildKeepProbValue {
|
||||
public:
|
||||
BuildKeepProbValue() = default;
|
||||
AnfNodePtr operator()(const PatternMap &m) const {
|
||||
auto node = m.Get(kMDropoutGrad);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto dropout_grad_cnode = node->cast<CNodePtr>();
|
||||
CheckCNodeInputSize(dropout_grad_cnode, kDropoutGradInputTensorNum);
|
||||
|
||||
auto func_graph = node->func_graph();
|
||||
auto grad_input_type_id = GetInputXDataType(dropout_grad_cnode);
|
||||
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id);
|
||||
return keep_prob_value;
|
||||
}
|
||||
};
|
||||
|
||||
AnfNodePtr BuildDropoutDoMask(const PatternMap &m, const AnfNodePtr &) {
|
||||
auto node = m.Get(kMDropoutGrad);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto dropout_grad_cnode = node->cast<CNodePtr>();
|
||||
CheckCNodeInputSize(dropout_grad_cnode, kDropoutGradInputTensorNum);
|
||||
|
||||
auto func_graph = dropout_grad_cnode->func_graph();
|
||||
auto grad_input_type_id = GetInputXDataType(dropout_grad_cnode);
|
||||
auto grad_input_shape = GetDropoutInputShape(dropout_grad_cnode->input(kIndex1));
|
||||
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id);
|
||||
auto use_v3 = WhetherUseDropoutV3(dropout_grad_cnode, grad_input_shape);
|
||||
|
||||
// DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter
|
||||
// in that scene, need to be updated.
|
||||
auto mask_input = dropout_grad_cnode->input(kIndex2);
|
||||
MS_EXCEPTION_IF_NULL(mask_input);
|
||||
if (mask_input->isa<Parameter>()) {
|
||||
// update abstract
|
||||
auto mask_abstract = mask_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(mask_abstract);
|
||||
auto grad_shape_vec = grad_input_shape->shape();
|
||||
auto mask_shape =
|
||||
use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec);
|
||||
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
|
||||
mask_input->set_abstract(mask_abstract);
|
||||
// update kernel info
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeUInt8});
|
||||
kernel_build_info_builder->SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get());
|
||||
} else if (IsPrimitiveCNode(mask_input, prim::kPrimTupleGetItem)) {
|
||||
auto mask_input_cnode = mask_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mask_input_cnode);
|
||||
auto tuple_input = mask_input_cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(tuple_input);
|
||||
if (IsValueNode<ValueTuple>(tuple_input)) {
|
||||
auto tuple_abstract = tuple_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
abstract::AbstractSequencePtr sequence_abstract_ptr = tuple_abstract->cast<abstract::AbstractSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sequence_abstract_ptr);
|
||||
// Dropout's outputs only have two elements.
|
||||
if (sequence_abstract_ptr->size() != kIndex2) {
|
||||
MS_LOG(EXCEPTION) << "Dropout's outputs have more than two elements, " << sequence_abstract_ptr->size();
|
||||
}
|
||||
abstract::AbstractBasePtrList abs{};
|
||||
abs.push_back(sequence_abstract_ptr->elements()[0]);
|
||||
// modify mask abstract
|
||||
auto mask_abstract = mask_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(mask_abstract);
|
||||
auto grad_shape_vec = grad_input_shape->shape();
|
||||
auto mask_shape =
|
||||
use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec);
|
||||
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
|
||||
mask_input->set_abstract(mask_abstract);
|
||||
abs.push_back(mask_abstract);
|
||||
auto new_abstract = std::make_shared<abstract::AbstractTuple>(abs);
|
||||
tuple_input->set_abstract(new_abstract);
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDropoutDoMask
|
||||
auto do_mask_abstract =
|
||||
std::make_shared<abstract::AbstractTensor>(TypeIdToType(grad_input_type_id), grad_input_shape);
|
||||
auto dropout_do_mask = CreateDropoutDoMaskCNode(func_graph, dropout_grad_cnode,
|
||||
{dropout_grad_cnode->input(kIndex1), mask_input, keep_prob_value},
|
||||
do_mask_abstract, use_v3);
|
||||
return dropout_do_mask;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef DropoutAndDropoutGradUnifyMindIR::DefinePattern() const {
|
||||
|
@ -433,6 +524,12 @@ const AnfNodePtr DropoutUnifyMindIR0::Process(const FuncGraphPtr &func_graph, co
|
|||
return tuple_cnode;
|
||||
}
|
||||
|
||||
std::vector<std::string> DropoutUnifyMindIR1::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimDropout->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef DropoutUnifyMindIR1::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
return VectorRef({prim::kPrimDropout, X});
|
||||
|
@ -477,80 +574,17 @@ const AnfNodePtr DropoutUnifyMindIR1::Process(const FuncGraphPtr &func_graph, co
|
|||
return make_tuple;
|
||||
}
|
||||
|
||||
const BaseRef DropoutGradUnifyMindIR::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr Y = std::make_shared<Var>();
|
||||
auto dropout_grad_prim = std::make_shared<Primitive>(kDropoutGradOpName);
|
||||
return VectorRef({dropout_grad_prim, X, Y});
|
||||
void DropoutGradUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern)
|
||||
.AddVar(kX1)
|
||||
.AddVar(kX2)
|
||||
.AddCNode(kMDropoutGrad, {std::make_shared<Primitive>(kDropoutGradOpName), kX1, kX2});
|
||||
}
|
||||
|
||||
const AnfNodePtr DropoutGradUnifyMindIR::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto dropout_grad_cnode = node->cast<CNodePtr>();
|
||||
CheckCNodeInputSize(dropout_grad_cnode, kDropoutGradInputTensorNum);
|
||||
|
||||
auto grad_input_type_id = GetInputXDataType(dropout_grad_cnode);
|
||||
auto grad_input_shape = GetDropoutInputShape(dropout_grad_cnode->input(kIndex1));
|
||||
auto keep_prob_value = CreateKeepPorbValueNode(func_graph, dropout_grad_cnode, grad_input_type_id);
|
||||
auto use_v3 = WhetherUseDropoutV3(dropout_grad_cnode, grad_input_shape);
|
||||
|
||||
// DropoutGrad may not in the same graph with Dropout in heterogeneous scene, and mask input which is a parameter
|
||||
// in that scene, need to be updated.
|
||||
auto mask_input = dropout_grad_cnode->input(kIndex2);
|
||||
MS_EXCEPTION_IF_NULL(mask_input);
|
||||
if (mask_input->isa<Parameter>()) {
|
||||
// update abstract
|
||||
auto mask_abstract = mask_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(mask_abstract);
|
||||
auto grad_shape_vec = grad_input_shape->shape();
|
||||
auto mask_shape =
|
||||
use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec);
|
||||
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
|
||||
mask_input->set_abstract(mask_abstract);
|
||||
// update kernel info
|
||||
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
|
||||
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{kNumberTypeUInt8});
|
||||
kernel_build_info_builder->SetOutputsKernelObjectType({kernel::KernelObjectType::TENSOR});
|
||||
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), mask_input.get());
|
||||
} else if (IsPrimitiveCNode(mask_input, prim::kPrimTupleGetItem)) {
|
||||
auto mask_input_cnode = mask_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(mask_input_cnode);
|
||||
auto tuple_input = mask_input_cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(tuple_input);
|
||||
if (IsValueNode<ValueTuple>(tuple_input)) {
|
||||
auto tuple_abstract = tuple_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(tuple_abstract);
|
||||
abstract::AbstractSequencePtr sequence_abstract_ptr = tuple_abstract->cast<abstract::AbstractSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sequence_abstract_ptr);
|
||||
// Dropout's outputs only have two elements.
|
||||
if (sequence_abstract_ptr->size() != kIndex2) {
|
||||
MS_LOG(EXCEPTION) << "Dropout's outputs have more than two elements, " << sequence_abstract_ptr->size();
|
||||
}
|
||||
abstract::AbstractBasePtrList abs{};
|
||||
abs.push_back(sequence_abstract_ptr->elements()[0]);
|
||||
// modify mask abstract
|
||||
auto mask_abstract = mask_input->abstract();
|
||||
MS_EXCEPTION_IF_NULL(mask_abstract);
|
||||
auto grad_shape_vec = grad_input_shape->shape();
|
||||
auto mask_shape =
|
||||
use_v3 ? CalGenMaskV3OutputShape(grad_shape_vec, kNumberTypeUInt8) : CalGenMaskOutputShape(grad_shape_vec);
|
||||
mask_abstract = std::make_shared<abstract::AbstractTensor>(kUInt8, mask_shape);
|
||||
mask_input->set_abstract(mask_abstract);
|
||||
abs.push_back(mask_abstract);
|
||||
auto new_abstract = std::make_shared<abstract::AbstractTuple>(abs);
|
||||
tuple_input->set_abstract(new_abstract);
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDropoutDoMask
|
||||
auto do_mask_abstract =
|
||||
std::make_shared<abstract::AbstractTensor>(TypeIdToType(grad_input_type_id), grad_input_shape);
|
||||
auto dropout_do_mask = CreateDropoutDoMaskCNode(func_graph, dropout_grad_cnode,
|
||||
{dropout_grad_cnode->input(kIndex1), mask_input, keep_prob_value},
|
||||
do_mask_abstract, use_v3);
|
||||
return dropout_do_mask;
|
||||
void DropoutGradUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern)
|
||||
.AddValueNode(kKeepProbValue, BuildKeepProbValue())
|
||||
.AddCNode(kRDropoutDoMask, {std::make_shared<Primitive>(kDropoutDoMaskOpName), kX1, kX2, kKeepProbValue},
|
||||
BuildDropoutDoMask);
|
||||
}
|
||||
} // namespace mindspore::opt
|
||||
|
|
|
@ -17,7 +17,10 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_DROPOUT_UNIFY_MINDIR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -49,15 +52,18 @@ class DropoutUnifyMindIR1 : public PatternProcessPass {
|
|||
~DropoutUnifyMindIR1() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
|
||||
class DropoutGradUnifyMindIR : public PatternProcessPass {
|
||||
class DropoutGradUnifyMindIR : public PatternToPatternPass {
|
||||
public:
|
||||
explicit DropoutGradUnifyMindIR(bool multigraph = true)
|
||||
: PatternProcessPass("dropoutgrad_unify_mindir", multigraph) {}
|
||||
DropoutGradUnifyMindIR() : PatternToPatternPass("dropoutgrad_unify_mindir", true) {}
|
||||
~DropoutGradUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -27,6 +28,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
std::vector<std::string> FSEDecodeAdjust::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(std::make_shared<Primitive>(kFSEDecodeOpName)->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef FSEDecodeAdjust::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
auto prim = std::make_shared<Primitive>(kFSEDecodeOpName);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FSE_DECODE_ADJUST_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
|
||||
|
@ -28,6 +29,9 @@ class FSEDecodeAdjust : public PatternProcessPass {
|
|||
~FSEDecodeAdjust() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "utils/ms_context.h"
|
||||
|
@ -120,6 +121,13 @@ void MaxPool2MaxPoolWithArgmax::SetNodeAttrs(const CNodePtr &maxpool, const CNod
|
|||
common::AnfAlgo::SetNodeAttr(kAttrKernelSize, MakeValue(ksize), maxpool_grad_argmax);
|
||||
}
|
||||
|
||||
std::vector<std::string> MaxPool2MaxPoolWithArgmax::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimMaxPool->name());
|
||||
ret.emplace_back(prim::kPrimMaxPoolGrad->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef MaxPool2MaxPoolWithArgmax::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr Y = std::make_shared<Var>();
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -36,6 +37,7 @@ class MaxPool2MaxPoolWithArgmax : public PatternProcessPass {
|
|||
const std::vector<AnfNodePtr> &maxpool_argmax_outputs) const;
|
||||
void SetNodeAttrs(const CNodePtr &maxpool, const CNodePtr &maxpool_grad, const CNodePtr &maxpool_argmax,
|
||||
const CNodePtr &maxpool_grad_argmax) const;
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -32,6 +32,15 @@ constexpr size_t kMaxPoolGradWithArgmaxInputTensorNum = 3;
|
|||
constexpr size_t kMaxPoolGradWithArgmaxInputNum = 4;
|
||||
constexpr size_t kMaxPoolWithArgmaxShape = 4;
|
||||
constexpr size_t kAlignBytes = 16;
|
||||
constexpr auto kX1 = "X1";
|
||||
constexpr auto kX2 = "X2";
|
||||
constexpr auto kMaxPoolIndex = "index0";
|
||||
constexpr auto kMMaxPool = "m_max_pool";
|
||||
constexpr auto kRMaxPool = "r_max_pool";
|
||||
constexpr auto kMMaxpoolWithArgmax = "m_maxpool_with_argmax";
|
||||
constexpr auto kMTupleGetitem0 = "m_tuple_getitem0";
|
||||
constexpr auto kMMaxpoolGradWithArgmax = "m_maxpool_grad_with_argmax";
|
||||
constexpr auto kRMaxpoolGradWithArgmax = "r_maxpool_grad_with_argmax";
|
||||
|
||||
bool IsC(const BaseRef &n) {
|
||||
if (utils::isa<AnfNodePtr>(n)) {
|
||||
|
@ -48,17 +57,9 @@ CNodePtr GetMaxPoolWithArgmax(const CNodePtr &maxpool_grad_with_argmax) {
|
|||
MS_EXCEPTION_IF_NULL(tuple_getitem0_anf);
|
||||
return tuple_getitem0_anf->cast<CNodePtr>();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef MaxPoolWithArgmaxUnifyMindIR::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VectorRef pattern({prim::kPrimMaxPoolWithArgmax, X});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
AnfNodePtr BuildMaxPoolWithArgmax(const PatternMap &m, const AnfNodePtr &) {
|
||||
auto node = m.Get(kMMaxPool);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto maxpool_with_argmax = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(maxpool_with_argmax);
|
||||
|
@ -85,19 +86,8 @@ const AnfNodePtr MaxPoolWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph
|
|||
return maxpool_with_argmax;
|
||||
}
|
||||
|
||||
const BaseRef MaxPoolGradWithArgmaxUnifyMindIR::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VarPtr Y = std::make_shared<Var>();
|
||||
VarPtr index0 = std::make_shared<CondVar>(IsC);
|
||||
VectorRef maxpool_with_argmax({prim::kPrimMaxPoolWithArgmax, X});
|
||||
VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, maxpool_with_argmax, index0});
|
||||
VectorRef maxpool_grad_with_argmax({prim::kPrimMaxPoolGradWithArgmax, X, Y, tuple_getitem0});
|
||||
return maxpool_grad_with_argmax;
|
||||
}
|
||||
|
||||
const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
AnfNodePtr BuildMaxPoolGradWithArgmax(const PatternMap &m, const AnfNodePtr &) {
|
||||
auto node = m.Get(kMMaxpoolGradWithArgmax);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto maxpool_grad_with_argmax = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(maxpool_grad_with_argmax);
|
||||
|
@ -122,5 +112,29 @@ const AnfNodePtr MaxPoolGradWithArgmaxUnifyMindIR::Process(const FuncGraphPtr &g
|
|||
|
||||
return maxpool_grad_with_argmax;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void MaxPoolWithArgmaxUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddVar(kX1).AddCNode(kMMaxPool, {prim::kPrimMaxPoolWithArgmax, kX1});
|
||||
}
|
||||
|
||||
void MaxPoolWithArgmaxUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern).AddCNode(kRMaxPool, {prim::kPrimMaxPoolWithArgmax, kX1}, BuildMaxPoolWithArgmax);
|
||||
}
|
||||
|
||||
void MaxPoolGradWithArgmaxUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern)
|
||||
.AddVar(kX1)
|
||||
.AddVar(kX2)
|
||||
.AddVar(kMaxPoolIndex, IsC)
|
||||
.AddCNode(kMMaxpoolWithArgmax, {prim::kPrimMaxPoolWithArgmax, kX1})
|
||||
.AddCNode(kMTupleGetitem0, {prim::kPrimTupleGetItem, kMMaxpoolWithArgmax, kMaxPoolIndex})
|
||||
.AddCNode(kMMaxpoolGradWithArgmax, {prim::kPrimMaxPoolGradWithArgmax, kX1, kX2, kMTupleGetitem0});
|
||||
}
|
||||
void MaxPoolGradWithArgmaxUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern)
|
||||
.AddCNode(kRMaxpoolGradWithArgmax, {prim::kPrimMaxPoolGradWithArgmax, kX1, kX2, kMTupleGetitem0},
|
||||
BuildMaxPoolGradWithArgmax);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,27 +16,27 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_MAXPOOL_WITH_ARGMAX_UNIFY_MINDIR_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class MaxPoolWithArgmaxUnifyMindIR : public PatternProcessPass {
|
||||
class MaxPoolWithArgmaxUnifyMindIR : public PatternToPatternPass {
|
||||
public:
|
||||
explicit MaxPoolWithArgmaxUnifyMindIR(bool multigraph = true)
|
||||
: PatternProcessPass("maxpool_with_argmax_unify_mindir", multigraph) {}
|
||||
MaxPoolWithArgmaxUnifyMindIR() : PatternToPatternPass("maxpool_with_argmax_unify_mindir", true) {}
|
||||
~MaxPoolWithArgmaxUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
};
|
||||
|
||||
class MaxPoolGradWithArgmaxUnifyMindIR : public PatternProcessPass {
|
||||
class MaxPoolGradWithArgmaxUnifyMindIR : public PatternToPatternPass {
|
||||
public:
|
||||
explicit MaxPoolGradWithArgmaxUnifyMindIR(bool multigraph = true)
|
||||
: PatternProcessPass("maxpool_grad_with_argmax_unify_mindir", multigraph) {}
|
||||
MaxPoolGradWithArgmaxUnifyMindIR() : PatternToPatternPass("maxpool_grad_with_argmax_unify_mindir", true) {}
|
||||
~MaxPoolGradWithArgmaxUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -912,6 +912,12 @@ CNodePtr NeighborExchangeV2GradUnifyMindIR::CreateSplitGradNodes(const FuncGraph
|
|||
return addn;
|
||||
}
|
||||
|
||||
std::vector<std::string> NeighborExchangeV2UnifyMindIR::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimNeighborExchangeV2->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef NeighborExchangeV2UnifyMindIR::DefinePattern() const {
|
||||
return VectorRef({prim::kPrimNeighborExchangeV2, std::make_shared<SeqVar>()});
|
||||
}
|
||||
|
@ -929,9 +935,16 @@ const AnfNodePtr NeighborExchangeV2UnifyMindIR::Process(const FuncGraphPtr &grap
|
|||
return concat;
|
||||
}
|
||||
|
||||
std::vector<std::string> NeighborExchangeV2GradUnifyMindIR::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimNeighborExchangeV2Grad->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef NeighborExchangeV2GradUnifyMindIR::DefinePattern() const {
|
||||
return VectorRef({prim::kPrimNeighborExchangeV2Grad, std::make_shared<SeqVar>()});
|
||||
}
|
||||
|
||||
const AnfNodePtr NeighborExchangeV2GradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <string>
|
||||
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
|
@ -49,6 +51,7 @@ class NeighborExchangeV2UnifyMindIR : public PatternProcessPass {
|
|||
const CNodePtr &all_to_all_v) const;
|
||||
CNodePtr CreateConcatNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2,
|
||||
const CNodePtr &all_to_all_v) const;
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
|
||||
class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass {
|
||||
|
@ -68,6 +71,7 @@ class NeighborExchangeV2GradUnifyMindIR : public PatternProcessPass {
|
|||
CNodePtr CreateSplitGradNodes(const FuncGraphPtr &graph, const CNodePtr &neighbor_exchange_v2_grad,
|
||||
const CNodePtr &all_to_all_v, const std::vector<CNodePtr> &split_nodes,
|
||||
const std::vector<int64_t> &split_num) const;
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
|
||||
} // namespace opt
|
||||
|
|
|
@ -30,12 +30,27 @@ constexpr size_t kFtrlOutputNum = 3;
|
|||
constexpr size_t kMomentumOutputNum = 2;
|
||||
constexpr size_t kRMSPropOutputNum = 3;
|
||||
constexpr size_t kCenteredRMSPropOutputNum = 4;
|
||||
constexpr auto kOptVar = "var";
|
||||
constexpr auto kOptAccum = "accum";
|
||||
constexpr auto kOptLinear = "linear";
|
||||
constexpr auto kOptGrad = "grad";
|
||||
constexpr auto kOptLr = "lr";
|
||||
constexpr auto kOptL1 = "l1";
|
||||
constexpr auto kOptL2 = "l2";
|
||||
constexpr auto kOptLrPower = "lr_power";
|
||||
constexpr auto kOptU = "u";
|
||||
constexpr auto kOptIndex = "index";
|
||||
constexpr auto kMomentum = "momentum";
|
||||
constexpr auto kInputs = "inputs";
|
||||
constexpr auto kMg = "mg";
|
||||
constexpr auto kMs = "ms";
|
||||
constexpr auto kMom = "mom";
|
||||
constexpr auto kRho = "rho";
|
||||
constexpr auto kEpsilon = "epsilon";
|
||||
constexpr auto kMOptimizer = "m_optimizer";
|
||||
constexpr auto kRTupleGet = "r_tuple_get";
|
||||
|
||||
CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const size_t output_size,
|
||||
const PatternProcessPass &pass) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
bool CheckNode(const AnfNodePtr &node) {
|
||||
auto cnode_ptr = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
|
||||
|
@ -43,89 +58,122 @@ CNodePtr ProcessOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, const
|
|||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
|
||||
if (common::AnfAlgo::HasNodeAttr("optim_output_passed", cnode_ptr) && abstract->isa<abstract::AbstractTuple>()) {
|
||||
return nullptr;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
AnfNodePtr BuildZero(const PatternMap &) { return NewValueNode(static_cast<int64_t>(0)); }
|
||||
} // namespace
|
||||
|
||||
AnfNodePtr BuildTupleGetFunc::operator()(const PatternMap &m, const AnfNodePtr &get_item) const {
|
||||
auto node = m.Get(kMOptimizer);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto cnode_ptr = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode_ptr);
|
||||
|
||||
auto abstract = cnode_ptr->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abstract);
|
||||
common::AnfAlgo::SetNodeAttr("optim_output_passed", MakeValue(true), cnode_ptr);
|
||||
|
||||
std::vector<AbstractBasePtr> abstract_list;
|
||||
for (size_t i = 0; i < output_size; i++) {
|
||||
for (size_t i = 0; i < output_size_; i++) {
|
||||
abstract_list.push_back(abstract->Clone());
|
||||
}
|
||||
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
||||
cnode_ptr->set_abstract(abstract_tuple);
|
||||
|
||||
auto index = NewValueNode(static_cast<int64_t>(0));
|
||||
auto get_item = pass.NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode_ptr, index}, graph);
|
||||
MS_EXCEPTION_IF_NULL(get_item);
|
||||
|
||||
get_item->set_abstract(abstract->Clone());
|
||||
return get_item;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef FtrlUnifyOutput::DefinePattern() const {
|
||||
VarPtr var = std::make_shared<Var>();
|
||||
VarPtr accum = std::make_shared<Var>();
|
||||
VarPtr linear = std::make_shared<Var>();
|
||||
VarPtr grad = std::make_shared<Var>();
|
||||
VarPtr lr = std::make_shared<Var>();
|
||||
VarPtr l1 = std::make_shared<Var>();
|
||||
VarPtr l2 = std::make_shared<Var>();
|
||||
VarPtr lr_power = std::make_shared<Var>();
|
||||
VarPtr u = std::make_shared<SeqVar>();
|
||||
VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power, u});
|
||||
return pattern;
|
||||
bool FtrlUnifyOutput::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &node) const {
|
||||
return CheckNode(node);
|
||||
}
|
||||
|
||||
const AnfNodePtr FtrlUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const {
|
||||
return ProcessOutput(graph, node, kFtrlOutputNum, *this);
|
||||
void FtrlUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern)
|
||||
.AddVar(kOptVar)
|
||||
.AddVar(kOptAccum)
|
||||
.AddVar(kOptLinear)
|
||||
.AddVar(kOptGrad)
|
||||
.AddVar(kOptLr)
|
||||
.AddVar(kOptL1)
|
||||
.AddVar(kOptL2)
|
||||
.AddVar(kOptLrPower)
|
||||
.AddVar(kOptU)
|
||||
.AddCNode(kMOptimizer, {prim::kPrimApplyFtrl, kOptVar, kOptAccum, kOptLinear, kOptGrad, kOptLr, kOptL1, kOptL2,
|
||||
kOptLrPower, kOptU});
|
||||
}
|
||||
|
||||
const BaseRef MomentumUnifyOutput::DefinePattern() const {
|
||||
VarPtr var = std::make_shared<Var>();
|
||||
VarPtr accum = std::make_shared<Var>();
|
||||
VarPtr lr = std::make_shared<Var>();
|
||||
VarPtr grad = std::make_shared<Var>();
|
||||
VarPtr momentum = std::make_shared<Var>();
|
||||
VarPtr u = std::make_shared<SeqVar>();
|
||||
VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum, u});
|
||||
return pattern;
|
||||
void FtrlUnifyOutput::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern)
|
||||
.AddValueNode(kOptIndex, BuildZero)
|
||||
.AddCNode(kRTupleGet, {prim::kPrimTupleGetItem, kMOptimizer, kOptIndex}, BuildTupleGetFunc(kFtrlOutputNum));
|
||||
}
|
||||
|
||||
const AnfNodePtr MomentumUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
return ProcessOutput(graph, node, kMomentumOutputNum, *this);
|
||||
bool MomentumUnifyOutput::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &node) const {
|
||||
return CheckNode(node);
|
||||
}
|
||||
|
||||
const BaseRef RMSPropUnifyOutput::DefinePattern() const {
|
||||
VarPtr inputs = std::make_shared<SeqVar>();
|
||||
VectorRef pattern({prim::kPrimApplyRMSProp, inputs});
|
||||
return pattern;
|
||||
void MomentumUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern)
|
||||
.AddVar(kOptVar)
|
||||
.AddVar(kOptAccum)
|
||||
.AddVar(kOptLr)
|
||||
.AddVar(kOptGrad)
|
||||
.AddVar(kMomentum)
|
||||
.AddVar(kOptU)
|
||||
.AddCNode(kMOptimizer, {prim::kPrimApplyMomentum, kOptVar, kOptAccum, kOptLr, kOptGrad, kMomentum, kOptU});
|
||||
}
|
||||
|
||||
const AnfNodePtr RMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
return ProcessOutput(graph, node, kRMSPropOutputNum, *this);
|
||||
void MomentumUnifyOutput::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern)
|
||||
.AddValueNode(kOptIndex, BuildZero)
|
||||
.AddCNode(kRTupleGet, {prim::kPrimTupleGetItem, kMOptimizer, kOptIndex}, BuildTupleGetFunc(kMomentumOutputNum));
|
||||
}
|
||||
|
||||
const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const {
|
||||
VarPtr var = std::make_shared<Var>();
|
||||
VarPtr mg = std::make_shared<Var>();
|
||||
VarPtr ms = std::make_shared<Var>();
|
||||
VarPtr mom = std::make_shared<Var>();
|
||||
VarPtr grad = std::make_shared<Var>();
|
||||
VarPtr lr = std::make_shared<Var>();
|
||||
VarPtr rho = std::make_shared<Var>();
|
||||
VarPtr momentum = std::make_shared<Var>();
|
||||
VarPtr epsilon = std::make_shared<Var>();
|
||||
VarPtr u = std::make_shared<SeqVar>();
|
||||
VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon, u});
|
||||
return pattern;
|
||||
bool RMSPropUnifyOutput::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &node) const {
|
||||
return CheckNode(node);
|
||||
}
|
||||
|
||||
const AnfNodePtr CenteredRMSPropUnifyOutput::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
return ProcessOutput(graph, node, kCenteredRMSPropOutputNum, *this);
|
||||
void RMSPropUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddSeqVar(kInputs).AddCNode(kMOptimizer, {prim::kPrimApplyRMSProp, kInputs});
|
||||
}
|
||||
|
||||
void RMSPropUnifyOutput::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern)
|
||||
.AddValueNode(kOptIndex, BuildZero)
|
||||
.AddCNode(kRTupleGet, {prim::kPrimTupleGetItem, kMOptimizer, kOptIndex}, BuildTupleGetFunc(kRMSPropOutputNum));
|
||||
}
|
||||
|
||||
bool CenteredRMSPropUnifyOutput::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &,
|
||||
const AnfNodePtr &node) const {
|
||||
return CheckNode(node);
|
||||
}
|
||||
|
||||
void CenteredRMSPropUnifyOutput::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern)
|
||||
.AddVar(kOptVar)
|
||||
.AddVar(kMg)
|
||||
.AddVar(kMs)
|
||||
.AddVar(kMom)
|
||||
.AddVar(kOptGrad)
|
||||
.AddVar(kOptLr)
|
||||
.AddVar(kRho)
|
||||
.AddVar(kMomentum)
|
||||
.AddVar(kEpsilon)
|
||||
.AddVar(kOptU)
|
||||
.AddCNode(kMOptimizer, {prim::kPrimApplyCenteredRMSProp, kOptVar, kMg, kMs, kMom, kOptGrad, kOptLr, kRho, kMomentum,
|
||||
kEpsilon, kOptU});
|
||||
}
|
||||
|
||||
void CenteredRMSPropUnifyOutput::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern)
|
||||
.AddValueNode(kOptIndex, BuildZero)
|
||||
.AddCNode(kRTupleGet, {prim::kPrimTupleGetItem, kMOptimizer, kOptIndex},
|
||||
BuildTupleGetFunc(kCenteredRMSPropOutputNum));
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,42 +16,52 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_OPTIMIZER_UNIFY_OUTPUT_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class FtrlUnifyOutput : public PatternProcessPass {
|
||||
class BuildTupleGetFunc {
|
||||
public:
|
||||
explicit FtrlUnifyOutput(bool multigraph = true) : PatternProcessPass("ftrl_unify_output", multigraph) {}
|
||||
explicit BuildTupleGetFunc(const size_t output_size) : output_size_(output_size) {}
|
||||
AnfNodePtr operator()(const PatternMap &m, const AnfNodePtr &get_item) const;
|
||||
size_t output_size_;
|
||||
};
|
||||
class FtrlUnifyOutput : public PatternToPatternPass {
|
||||
public:
|
||||
FtrlUnifyOutput() : PatternToPatternPass("ftrl_unify_output", true) {}
|
||||
~FtrlUnifyOutput() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
class MomentumUnifyOutput : public PatternProcessPass {
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
class MomentumUnifyOutput : public PatternToPatternPass {
|
||||
public:
|
||||
explicit MomentumUnifyOutput(bool multigraph = true) : PatternProcessPass("momentum_unify_output", multigraph) {}
|
||||
MomentumUnifyOutput() : PatternToPatternPass("momentum_unify_output", true) {}
|
||||
~MomentumUnifyOutput() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
class CenteredRMSPropUnifyOutput : public PatternProcessPass {
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
class CenteredRMSPropUnifyOutput : public PatternToPatternPass {
|
||||
public:
|
||||
explicit CenteredRMSPropUnifyOutput(bool multigraph = true)
|
||||
: PatternProcessPass("centered_rmsprop_unify_output", multigraph) {}
|
||||
CenteredRMSPropUnifyOutput() : PatternToPatternPass("centered_rmsprop_unify_output", true) {}
|
||||
~CenteredRMSPropUnifyOutput() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
};
|
||||
|
||||
class RMSPropUnifyOutput : public PatternProcessPass {
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
class RMSPropUnifyOutput : public PatternToPatternPass {
|
||||
public:
|
||||
explicit RMSPropUnifyOutput(bool multigraph = true) : PatternProcessPass("rmsprop_unify_output", multigraph) {}
|
||||
RMSPropUnifyOutput() : PatternToPatternPass("rmsprop_unify_output", true) {}
|
||||
~RMSPropUnifyOutput() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -28,6 +28,12 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
std::vector<std::string> QuantDTypeCastAdjust::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(std::make_shared<Primitive>(kQuantDTypeCastOpName)->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef QuantDTypeCastAdjust::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
auto prim = std::make_shared<Primitive>(kQuantDTypeCastOpName);
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_QUANT_DTYPE_CAST_ADJUST_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
|
||||
|
@ -28,6 +29,9 @@ class QuantDTypeCastAdjust : public PatternProcessPass {
|
|||
~QuantDTypeCastAdjust() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -34,6 +34,10 @@ namespace opt {
|
|||
namespace {
|
||||
constexpr size_t kSliceGradInputTensorNum = 4;
|
||||
constexpr size_t kSliceGradCangjieInputTensorNum = 2;
|
||||
constexpr auto kMSliceGrad = "m_slice_grad";
|
||||
constexpr auto kRPad = "r_pad";
|
||||
constexpr auto kX1 = "X1";
|
||||
constexpr auto kXs = "Xs";
|
||||
|
||||
std::vector<int64_t> GetInputXShape(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -47,19 +51,10 @@ std::vector<int64_t> GetTupleValue(const AnfNodePtr &node) {
|
|||
MS_EXCEPTION_IF_NULL(value_node->value());
|
||||
return GetValue<std::vector<int64_t>>(value_node->value());
|
||||
}
|
||||
} // namespace
|
||||
|
||||
const BaseRef SliceGradUnifyMindIR::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
VectorRef slice_grad({std::make_shared<Primitive>("SliceGrad"), Xs});
|
||||
return slice_grad;
|
||||
}
|
||||
|
||||
const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
AnfNodePtr BuildPad(const PatternMap &m, const AnfNodePtr &pad) {
|
||||
auto node = m.Get(kMSliceGrad);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto slice_grad = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(slice_grad);
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(slice_grad);
|
||||
|
@ -68,9 +63,6 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
|
|||
<< "] of node " + slice_grad->DebugString() + " is not equal to " << kSliceGradInputTensorNum
|
||||
<< " or " << kSliceGradCangjieInputTensorNum << trace::DumpSourceLines(node);
|
||||
}
|
||||
std::vector<AnfNodePtr> pad_inputs = {NewValueNode(std::make_shared<Primitive>(kPadDOpName)),
|
||||
slice_grad->input(kIndex1)};
|
||||
auto pad = NewCNode(pad_inputs, graph);
|
||||
MS_EXCEPTION_IF_NULL(pad);
|
||||
pad->set_scope(slice_grad->scope());
|
||||
pad->set_abstract(slice_grad->abstract());
|
||||
|
@ -80,12 +72,6 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
|
|||
std::vector<int64_t> begins;
|
||||
std::vector<int64_t> sizes;
|
||||
if (input_num == kSliceGradInputTensorNum) {
|
||||
auto begin_value = GetValueNode(slice_grad->input(kIndex3));
|
||||
auto size_value = GetValueNode(slice_grad->input(kIndex4));
|
||||
if (IsDynamic(x_shape) || begin_value == nullptr || size_value == nullptr || !begin_value->isa<ValueSequence>() ||
|
||||
!size_value->isa<ValueSequence>()) {
|
||||
return nullptr;
|
||||
}
|
||||
begins = GetTupleValue(slice_grad->input(kIndex3));
|
||||
sizes = GetTupleValue(slice_grad->input(kIndex4));
|
||||
} else {
|
||||
|
@ -108,5 +94,31 @@ const AnfNodePtr SliceGradUnifyMindIR::Process(const FuncGraphPtr &graph, const
|
|||
|
||||
return pad;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool SliceGradUnifyMindIR::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &node) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto slice_grad = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(slice_grad);
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(slice_grad);
|
||||
auto x_shape = GetInputXShape(slice_grad);
|
||||
if (input_num == kSliceGradInputTensorNum) {
|
||||
auto begin_value = GetValueNode(slice_grad->input(kIndex3));
|
||||
auto size_value = GetValueNode(slice_grad->input(kIndex4));
|
||||
if (IsDynamic(x_shape) || begin_value == nullptr || size_value == nullptr || !begin_value->isa<ValueSequence>() ||
|
||||
!size_value->isa<ValueSequence>()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void SliceGradUnifyMindIR::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddVar(kX1).AddSeqVar(kXs).AddCNode(kMSliceGrad, {std::make_shared<Primitive>("SliceGrad"), kX1, kXs});
|
||||
}
|
||||
|
||||
void SliceGradUnifyMindIR::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern).AddCNode(kRPad, {std::make_shared<Primitive>(kPadDOpName), kX1}, BuildPad);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,17 +16,19 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SLICE_GRAD_UNIFY_MINDIR_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SLICE_GRAD_UNIFY_MINDIR_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SliceGradUnifyMindIR : public PatternProcessPass {
|
||||
class SliceGradUnifyMindIR : public PatternToPatternPass {
|
||||
public:
|
||||
explicit SliceGradUnifyMindIR(bool multigraph = true) : PatternProcessPass("slice_grad_unify_mindir", multigraph) {}
|
||||
SliceGradUnifyMindIR() : PatternToPatternPass("slice_grad_unify_mindir", true) {}
|
||||
~SliceGradUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,19 +31,14 @@ constexpr size_t kBlockShapeDimNum = 2;
|
|||
constexpr auto kAttrBlockShape = "block_shape";
|
||||
constexpr auto kAttrPaddings = "paddings";
|
||||
constexpr auto kAttrCrops = "crops";
|
||||
} // namespace
|
||||
|
||||
const BaseRef SpaceToBatchNDAttrUpdate::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VectorRef pattern({prim::kPrimSpaceToBatchND, X});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
const AnfNodePtr SpaceToBatchNDAttrUpdate::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
constexpr auto kV = "V";
|
||||
constexpr auto kMSpace = "m_space";
|
||||
constexpr auto kRSpace = "r_space";
|
||||
constexpr auto kMBatch = "m_batch";
|
||||
constexpr auto kRBatch = "r_batch";
|
||||
|
||||
AnfNodePtr BuildSpace(const PatternMap &m, const AnfNodePtr &default_node) {
|
||||
auto node = m.Get(kMSpace);
|
||||
auto block_shape = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrBlockShape);
|
||||
if (block_shape.size() == kBlockShapeDimNum) {
|
||||
(void)block_shape.insert(block_shape.cbegin(), 1);
|
||||
|
@ -57,17 +52,8 @@ const AnfNodePtr SpaceToBatchNDAttrUpdate::Process(const FuncGraphPtr &graph, co
|
|||
return node;
|
||||
}
|
||||
|
||||
const BaseRef BatchToSpaceNDAttrUpdate::DefinePattern() const {
|
||||
VarPtr X = std::make_shared<Var>();
|
||||
VectorRef pattern({prim::kPrimBatchToSpaceND, X});
|
||||
return pattern;
|
||||
}
|
||||
|
||||
const AnfNodePtr BatchToSpaceNDAttrUpdate::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
AnfNodePtr BuildBatch(const PatternMap &m, const AnfNodePtr &default_node) {
|
||||
auto node = m.Get(kMBatch);
|
||||
auto block_shape = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(node, kAttrBlockShape);
|
||||
if (block_shape.size() == kBlockShapeDimNum) {
|
||||
(void)block_shape.insert(block_shape.cbegin(), 1);
|
||||
|
@ -80,5 +66,30 @@ const AnfNodePtr BatchToSpaceNDAttrUpdate::Process(const FuncGraphPtr &graph, co
|
|||
}
|
||||
return node;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool SpaceToBatchNDAttrUpdate::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
void SpaceToBatchNDAttrUpdate::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddVar(kV).AddCNode(kMSpace, {prim::kPrimSpaceToBatchND, kV});
|
||||
}
|
||||
|
||||
void SpaceToBatchNDAttrUpdate::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern).AddCNode(kRSpace, {prim::kPrimSpaceToBatchND, kV}, BuildSpace);
|
||||
}
|
||||
|
||||
bool BatchToSpaceNDAttrUpdate::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
void BatchToSpaceNDAttrUpdate::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddVar(kV).AddCNode(kMBatch, {prim::kPrimBatchToSpaceND, kV});
|
||||
}
|
||||
|
||||
void BatchToSpaceNDAttrUpdate::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern).AddCNode(kRBatch, {prim::kPrimBatchToSpaceND, kV}, BuildBatch);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -17,26 +17,26 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_SPACE_BATCH_ND_ATTR_UPDATE_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SpaceToBatchNDAttrUpdate : public PatternProcessPass {
|
||||
class SpaceToBatchNDAttrUpdate : public PatternToPatternPass {
|
||||
public:
|
||||
explicit SpaceToBatchNDAttrUpdate(bool multigraph = true)
|
||||
: PatternProcessPass("space_to_batch_nd_attr_update", multigraph) {}
|
||||
SpaceToBatchNDAttrUpdate() : PatternToPatternPass("space_to_batch_nd_attr_update", true) {}
|
||||
~SpaceToBatchNDAttrUpdate() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
|
||||
class BatchToSpaceNDAttrUpdate : public PatternProcessPass {
|
||||
class BatchToSpaceNDAttrUpdate : public PatternToPatternPass {
|
||||
public:
|
||||
explicit BatchToSpaceNDAttrUpdate(bool multigraph = true)
|
||||
: PatternProcessPass("batch_to_space_nd_attr_update", multigraph) {}
|
||||
BatchToSpaceNDAttrUpdate() : PatternToPatternPass("batch_to_space_nd_attr_update", true) {}
|
||||
~BatchToSpaceNDAttrUpdate() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -474,6 +474,12 @@ CNodePtr CreateMulInput(const FuncGraphPtr &graph, const CNodePtr &mul_node, con
|
|||
}
|
||||
} // namespace
|
||||
|
||||
std::vector<std::string> SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimSparseSoftmaxCrossEntropyWithLogits->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern() const {
|
||||
VarPtr x1 = std::make_shared<Var>();
|
||||
VarPtr x2 = std::make_shared<Var>();
|
||||
|
@ -634,6 +640,13 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process
|
|||
}
|
||||
}
|
||||
|
||||
std::vector<std::string> PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimSparseSoftmaxCrossEntropyWithLogits->name());
|
||||
ret.emplace_back(prim::kPrimMul->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::DefinePattern() const {
|
||||
VarPtr x1 = std::make_shared<Var>();
|
||||
VarPtr x2 = std::make_shared<Var>();
|
||||
|
@ -669,6 +682,14 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro
|
|||
return new_mul_node;
|
||||
}
|
||||
|
||||
std::vector<std::string> PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::MustExistPrimitiveName() const {
|
||||
std::vector<std::string> ret;
|
||||
ret.emplace_back(prim::kPrimSparseSoftmaxCrossEntropyWithLogits->name());
|
||||
ret.emplace_back(prim::kPrimCast->name());
|
||||
ret.emplace_back(prim::kPrimMul->name());
|
||||
return ret;
|
||||
}
|
||||
|
||||
const BaseRef PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2::DefinePattern() const {
|
||||
VarPtr x1 = std::make_shared<Var>();
|
||||
VarPtr x2 = std::make_shared<Var>();
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -31,6 +32,9 @@ class SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass
|
|||
~SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
|
||||
class GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public PatternProcessPass {
|
||||
|
@ -67,6 +71,9 @@ class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR : public Patter
|
|||
~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
|
||||
class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public PatternProcessPass {
|
||||
|
@ -76,6 +83,9 @@ class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2 : public Patt
|
|||
~PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV2() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
|
||||
private:
|
||||
std::vector<std::string> MustExistPrimitiveName() const override;
|
||||
};
|
||||
|
||||
class PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIRV3 : public PatternProcessPass {
|
||||
|
|
|
@ -24,21 +24,18 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
const BaseRef StridedSliceGradUpdateInputNames::DefinePattern() const {
|
||||
VarPtr Xs = std::make_shared<SeqVar>();
|
||||
auto strided_slice_grad_prim = std::make_shared<Primitive>(kStridedSliceGradOpName);
|
||||
return VectorRef({strided_slice_grad_prim, Xs});
|
||||
}
|
||||
namespace {
|
||||
constexpr auto kXs = "Xs";
|
||||
constexpr auto kMSliceGrad = "m_slice_grad";
|
||||
constexpr auto kRSliceGrad = "r_slice_grad";
|
||||
|
||||
const AnfNodePtr StridedSliceGradUpdateInputNames::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
AnfNodePtr BuildSliceGrad(const PatternMap &m, const AnfNodePtr &) {
|
||||
auto node = m.Get(kMSliceGrad);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto strided_slice_grad = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(strided_slice_grad);
|
||||
|
||||
const size_t shapex_index = 1;
|
||||
if (common::AnfAlgo::IsDynamicShape(strided_slice_grad)) {
|
||||
auto primitive = common::AnfAlgo::GetCNodePrimitive(strided_slice_grad);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto input_names_ptr = primitive->GetAttr(kAttrInputNames);
|
||||
|
@ -46,8 +43,26 @@ const AnfNodePtr StridedSliceGradUpdateInputNames::Process(const FuncGraphPtr &g
|
|||
auto input_names_vec = GetValue<std::vector<std::string>>(input_names_ptr);
|
||||
input_names_vec[shapex_index] = "shape";
|
||||
common::AnfAlgo::SetNodeAttr(kAttrInputNames, MakeValue(input_names_vec), strided_slice_grad);
|
||||
return strided_slice_grad;
|
||||
}
|
||||
return nullptr;
|
||||
} // namespace
|
||||
|
||||
bool StridedSliceGradUpdateInputNames::CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &,
|
||||
const AnfNodePtr &node) const {
|
||||
auto strided_slice_grad = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(strided_slice_grad);
|
||||
if (common::AnfAlgo::IsDynamicShape(strided_slice_grad)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void StridedSliceGradUpdateInputNames::DefineSrcPattern(SrcPattern *src_pattern) {
|
||||
(*src_pattern).AddSeqVar(kXs).AddCNode(kMSliceGrad, {std::make_shared<Primitive>(kStridedSliceGradOpName), kXs});
|
||||
}
|
||||
|
||||
void StridedSliceGradUpdateInputNames::DefineDstPattern(DstPattern *dst_pattern) {
|
||||
(*dst_pattern).AddCNode(kRSliceGrad, {std::make_shared<Primitive>(kStridedSliceGradOpName), kXs}, BuildSliceGrad);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -16,18 +16,19 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_UPDATE_INPUT_NAMES_STRIDED_SLICE_GRAD_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_MINDIR_UPDATE_INPUT_NAMES_STRIDED_SLICE_GRAD_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/common/optimizer/optimizer.h"
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class StridedSliceGradUpdateInputNames : public PatternProcessPass {
|
||||
class StridedSliceGradUpdateInputNames : public PatternToPatternPass {
|
||||
public:
|
||||
explicit StridedSliceGradUpdateInputNames(bool multigraph = true)
|
||||
: PatternProcessPass("update_strided_slice_grad_input_names", multigraph) {}
|
||||
StridedSliceGradUpdateInputNames() : PatternToPatternPass("update_strided_slice_grad_input_names", true) {}
|
||||
~StridedSliceGradUpdateInputNames() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override;
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override;
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -211,6 +211,7 @@ FuncGraphManager::FuncGraphManager(const std::vector<FuncGraphPtr> &roots, bool
|
|||
|
||||
void FuncGraphManager::Reset() {
|
||||
func_graphs_ = FuncGraphSet();
|
||||
func_graphs_index_ = FuncGraphIndexMap();
|
||||
all_nodes_ = AnfNodeSet();
|
||||
node_users_ = NodeUsersMap();
|
||||
signals_ = std::make_shared<Signals>();
|
||||
|
@ -285,6 +286,14 @@ FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) c
|
|||
return func_graphs_used_total_->func_graph_used_total_analysis()[fg];
|
||||
}
|
||||
|
||||
const FuncGraphIndexPtr &FuncGraphManager::func_graph_index(const FuncGraphPtr &fg) const {
|
||||
auto iter = func_graphs_index_.find(fg);
|
||||
if (iter == func_graphs_index_.end()) {
|
||||
MS_LOG(EXCEPTION) << "Func graph: " << fg->ToString() << " is not add FuncGraphIndexMap.";
|
||||
}
|
||||
return func_graphs_index_.at(fg);
|
||||
}
|
||||
|
||||
bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const {
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
MS_EXCEPTION_IF_NULL(recursive_);
|
||||
|
@ -346,6 +355,8 @@ void FuncGraphManager::AddFuncGraph(const FuncGraphPtr &func_graph, bool is_root
|
|||
(void)new_nodes.emplace_back(std::move(return_node));
|
||||
}
|
||||
|
||||
func_graphs_index_.emplace(func_graph, std::make_shared<FuncGraphPassIndex>());
|
||||
|
||||
// Acquire all nodes from func_graph.
|
||||
AcquireNodes(std::move(new_nodes));
|
||||
}
|
||||
|
@ -362,6 +373,7 @@ void FuncGraphManager::Clear() noexcept {
|
|||
}
|
||||
|
||||
func_graphs_.clear();
|
||||
func_graphs_index_.clear();
|
||||
all_nodes_.clear();
|
||||
node_users_.clear();
|
||||
roots_.clear();
|
||||
|
|
|
@ -53,10 +53,13 @@ using ChangePtr = std::unique_ptr<Change>;
|
|||
|
||||
class FuncGraphTransaction;
|
||||
class FuncGraphManager;
|
||||
class FuncGraphPassIndex;
|
||||
using FuncGraphManagerPtr = std::shared_ptr<FuncGraphManager>;
|
||||
using FuncGraphIndexPtr = std::shared_ptr<FuncGraphPassIndex>;
|
||||
|
||||
using AnfNodeIndexSet = CompactSet<std::pair<AnfNodePtr, int>>;
|
||||
using NodeUsersMap = mindspore::HashMap<AnfNodePtr, AnfNodeIndexSet, PointerHash<AnfNodePtr>>;
|
||||
using FuncGraphIndexMap = mindspore::HashMap<FuncGraphPtr, FuncGraphIndexPtr>;
|
||||
|
||||
using FuncGraphSetPair = std::pair<FuncGraphPtr, FuncGraphSet>;
|
||||
using FuncGraphSetPtr = std::shared_ptr<FuncGraphSet>;
|
||||
|
@ -80,6 +83,21 @@ using CNodeIndexPair = std::pair<AnfNodePtr, int>;
|
|||
using CNodeIndexPairPtr = std::shared_ptr<CNodeIndexPair>;
|
||||
using FuncGraphToFuncGraphSetMap = OrderedMap<FuncGraphPtr, FuncGraphSet>;
|
||||
|
||||
// For Fast Pass
|
||||
class FuncGraphPassIndex {
|
||||
public:
|
||||
FuncGraphPassIndex() : has_gen_index_(false) {}
|
||||
void set_has_gen_index(bool is_gen_index) { has_gen_index_ = is_gen_index; }
|
||||
bool has_gen_index() const { return has_gen_index_; }
|
||||
mindspore::HashMap<AnfNodePtr, FuncGraphWeakPtr> node_to_fg_;
|
||||
mindspore::HashMap<std::string, std::set<AnfNodePtr>> name_to_cnode_;
|
||||
mindspore::HashMap<AnfNodePtr, std::set<AnfNodePtr>> subgraph_out_caller_map_;
|
||||
mindspore::HashMap<AnfNodePtr, size_t> node_degree_;
|
||||
|
||||
private:
|
||||
bool has_gen_index_;
|
||||
};
|
||||
|
||||
// analysis base class, graphs analysis which need dynamic compute by DepCollector in each read
|
||||
class DepComputer {
|
||||
public:
|
||||
|
@ -331,6 +349,8 @@ class MS_CORE_API FuncGraphManager : public std::enable_shared_from_this<FuncGra
|
|||
|
||||
FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const;
|
||||
|
||||
const FuncGraphIndexPtr &func_graph_index(const FuncGraphPtr &fg) const;
|
||||
|
||||
bool recursive(const FuncGraphPtr &fg) const;
|
||||
std::shared_ptr<std::list<FuncGraphPtr>> recursive_graphs(const FuncGraphPtr &fg) const;
|
||||
|
||||
|
@ -361,6 +381,7 @@ class MS_CORE_API FuncGraphManager : public std::enable_shared_from_this<FuncGra
|
|||
|
||||
FuncGraphSet roots_; // Managed roots.
|
||||
FuncGraphSet func_graphs_; // Managed func graphs.
|
||||
FuncGraphIndexMap func_graphs_index_; // For Fast Pass
|
||||
|
||||
std::shared_ptr<Signals> signals_;
|
||||
|
||||
|
|
|
@ -132,7 +132,7 @@ bool FP32Imm::operator==(const Value &other) const {
|
|||
}
|
||||
}
|
||||
bool FP32Imm::operator==(const FP32Imm &other) const {
|
||||
if (std::isinf(v_) && std::isinf(other.v_)) {
|
||||
if ((std::isinf(v_) && std::isinf(other.v_)) || (std::isnan(v_) && std::isnan(other.v_))) {
|
||||
return true;
|
||||
}
|
||||
return fabs(v_ - other.v_) < DBL_EPSILON;
|
||||
|
@ -186,7 +186,7 @@ std::string ValueSequence::DumpText() const {
|
|||
}
|
||||
|
||||
bool FP64Imm::operator==(const FP64Imm &other) const {
|
||||
if (std::isinf(v_) && std::isinf(other.v_)) {
|
||||
if ((std::isinf(v_) && std::isinf(other.v_)) || (std::isnan(v_) && std::isnan(other.v_))) {
|
||||
return true;
|
||||
}
|
||||
return fabs(v_ - other.v_) < DBL_EPSILON;
|
||||
|
|
|
@ -0,0 +1,617 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pattern_to_pattern_pass_utils.h"
|
||||
#include "backend/common/optimizer/node_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
const auto kZero = 0;
|
||||
const auto kOne = 1;
|
||||
const auto kTwo = 2;
|
||||
const auto kThree = 3;
|
||||
|
||||
const auto kA = "a";
|
||||
const auto kB = "b";
|
||||
const auto kC = "c";
|
||||
const auto kD = "d";
|
||||
const auto kE = "e";
|
||||
const auto kAAddB = "a_add_b";
|
||||
const auto kCAddD = "c_add_d";
|
||||
const auto kMul = "mul";
|
||||
const auto kAdd = "add";
|
||||
|
||||
class TestFastMul0 : public PatternToPatternPass {
|
||||
// a*b + a*c -> a*(b+c)
|
||||
public:
|
||||
explicit TestFastMul0() : PatternToPatternPass("test_fast_mul0") {}
|
||||
~TestFastMul0() override = default;
|
||||
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override {
|
||||
(*src_pattern)
|
||||
.AddVar("a")
|
||||
.AddVar("b")
|
||||
.AddVar("c")
|
||||
.AddCNode("ab", {std::make_shared<Primitive>(kMulOpName), "a", "b"})
|
||||
.AddCNode("ac", {std::make_shared<Primitive>(kMulOpName), "a", "c"})
|
||||
.AddCNode("add", {std::make_shared<Primitive>(kAddOpName), "ab", "ac"});
|
||||
}
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override {
|
||||
(*dst_pattern)
|
||||
.AddCNode("bc", {std::make_shared<Primitive>(kAddOpName), "b", "c"})
|
||||
.AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "a", "bc"});
|
||||
}
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; }
|
||||
};
|
||||
|
||||
class TestFastMul1 : public PatternToPatternPass {
|
||||
// a*b + c*d -> a*c
|
||||
public:
|
||||
explicit TestFastMul1() : PatternToPatternPass("test_fast_mul1") {}
|
||||
~TestFastMul1() override = default;
|
||||
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override {
|
||||
(*src_pattern)
|
||||
.AddVar("a")
|
||||
.AddVar("b")
|
||||
.AddVar("c")
|
||||
.AddVar("d")
|
||||
.AddCNode("ab", {std::make_shared<Primitive>(kMulOpName), "a", "b"})
|
||||
.AddCNode("cd", {std::make_shared<Primitive>(kMulOpName), "c", "d"})
|
||||
.AddCNode("add", {std::make_shared<Primitive>(kAddOpName), "ab", "cd"});
|
||||
}
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override {
|
||||
(*dst_pattern).AddCNode("ad", {std::make_shared<Primitive>(kMulOpName), "a", "d"});
|
||||
}
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; }
|
||||
};
|
||||
|
||||
class TestFastMul2 : public PatternToPatternPass {
|
||||
// a*b -> b*a
|
||||
public:
|
||||
explicit TestFastMul2() : PatternToPatternPass("test_fast_mul2") {}
|
||||
~TestFastMul2() override = default;
|
||||
|
||||
void DefineSrcPattern(SrcPattern *src_pattern) override {
|
||||
(*src_pattern).AddSeqVar("Sv").AddCNode("ab", {std::make_shared<Primitive>(kMulOpName), "Sv"});
|
||||
}
|
||||
void DefineDstPattern(DstPattern *dst_pattern) override {
|
||||
auto ba = Unpacking("Sv");
|
||||
auto ab = Unpacking("Sv");
|
||||
ba[0] = ab[1];
|
||||
ba[1] = ab[0];
|
||||
(*dst_pattern).AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), ba});
|
||||
}
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; }
|
||||
};
|
||||
} // namespace
|
||||
|
||||
class TestFastPatternToPatternPass : public UT::Common {
|
||||
public:
|
||||
TestFastPatternToPatternPass() : fg_(std::make_shared<FuncGraph>()){};
|
||||
|
||||
public:
|
||||
FuncGraphPtr fg_;
|
||||
};
|
||||
|
||||
/// Feature: Fast PatternToPattern Pass
|
||||
/// Description: Fast PatternToPattern Pass rewrite graph
|
||||
/// Expectation: Get correct Graph
|
||||
TEST_F(TestFastPatternToPatternPass, Mul0) {
|
||||
// a*b + a*c -> a*(b+c)
|
||||
// init
|
||||
auto check = CheckPattern();
|
||||
auto pass = TestFastMul0();
|
||||
|
||||
// build func graph
|
||||
auto a = std::make_shared<AnfNode>(fg_);
|
||||
auto b = std::make_shared<AnfNode>(fg_);
|
||||
auto c = std::make_shared<AnfNode>(fg_);
|
||||
AnfNodePtr ab =
|
||||
std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, b}, fg_);
|
||||
AnfNodePtr ac =
|
||||
std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, c}, fg_);
|
||||
AnfNodePtr add = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), ab, ac}, fg_);
|
||||
|
||||
fg_->set_output(add);
|
||||
auto manager = MakeManager({fg_});
|
||||
if (manager) {
|
||||
manager->AddFuncGraph(fg_);
|
||||
fg_->set_manager(manager);
|
||||
}
|
||||
auto func_graph_index = manager->func_graph_index(fg_);
|
||||
GenIndex(fg_, func_graph_index);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 2);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
|
||||
|
||||
auto &add_set = func_graph_index->name_to_cnode_[kAddOpName];
|
||||
auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName];
|
||||
|
||||
ASSERT_TRUE(add_set.size() == 1);
|
||||
ASSERT_TRUE(mul_set.size() == 2);
|
||||
ASSERT_TRUE(add_set.find(add) != add_set.end());
|
||||
ASSERT_TRUE(mul_set.find(ab) != mul_set.end());
|
||||
ASSERT_TRUE(mul_set.find(ac) != mul_set.end());
|
||||
|
||||
auto new_node = pass.Run(fg_, add);
|
||||
ASSERT_NE(new_node, nullptr);
|
||||
(void)manager->Replace(add, new_node);
|
||||
pass.AfterProcess(add, new_node, fg_, func_graph_index);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("bc")) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("mul")) == 1);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
|
||||
|
||||
auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName];
|
||||
auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName];
|
||||
|
||||
ASSERT_TRUE(add_set_2.size() == 1);
|
||||
ASSERT_TRUE(mul_set_2.size() == 1);
|
||||
ASSERT_TRUE(add_set_2.find(pass.m_->Get("bc")) != add_set_2.end());
|
||||
ASSERT_TRUE(mul_set_2.find(pass.m_->Get("mul")) != mul_set_2.end());
|
||||
|
||||
// build pattern
|
||||
check.src_pattern_.AddVar("a")
|
||||
.AddVar("b")
|
||||
.AddVar("c")
|
||||
.AddCNode("bc", {std::make_shared<Primitive>(kAddOpName), "b", "c"})
|
||||
.AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "a", "bc"});
|
||||
|
||||
// pattern engine
|
||||
ASSERT_TRUE(check.build_pattern_map(new_node));
|
||||
|
||||
// check
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c));
|
||||
ASSERT_EQ(check.m_->Get("bc")->cast<CNodePtr>()->inputs().size(), 3);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(0),
|
||||
NewValueNode(std::make_shared<Primitive>(kAddOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(1), b));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(2), c));
|
||||
ASSERT_EQ(check.m_->Get("mul")->cast<CNodePtr>()->inputs().size(), 3);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(0),
|
||||
NewValueNode(std::make_shared<Primitive>(kMulOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(1), a));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(2), check.m_->Get("bc")));
|
||||
}
|
||||
|
||||
/// Feature: Fast PatternToPattern Pass
|
||||
/// Description: Fast PatternToPattern Pass rewrite graph
|
||||
/// Expectation: Get correct Graph
|
||||
TEST_F(TestFastPatternToPatternPass, Mul0NotRoot) {
|
||||
// (a*b + a*c) + d -> a*(b+c) + d
|
||||
// init
|
||||
auto check = CheckPattern();
|
||||
auto pass = TestFastMul0();
|
||||
|
||||
// build func graph
|
||||
auto a = std::make_shared<AnfNode>(fg_);
|
||||
auto b = std::make_shared<AnfNode>(fg_);
|
||||
auto c = std::make_shared<AnfNode>(fg_);
|
||||
auto d = std::make_shared<AnfNode>(fg_);
|
||||
AnfNodePtr ab =
|
||||
std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, b}, fg_);
|
||||
AnfNodePtr ac =
|
||||
std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, c}, fg_);
|
||||
AnfNodePtr add = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), ab, ac}, fg_);
|
||||
AnfNodePtr add1 = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), add, d}, fg_);
|
||||
|
||||
fg_->set_output(add1);
|
||||
auto manager = MakeManager({fg_});
|
||||
if (manager) {
|
||||
manager->AddFuncGraph(fg_);
|
||||
fg_->set_manager(manager);
|
||||
}
|
||||
auto func_graph_index = manager->func_graph_index(fg_);
|
||||
GenIndex(fg_, func_graph_index);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 2);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
|
||||
|
||||
auto &add_set = func_graph_index->name_to_cnode_[kAddOpName];
|
||||
auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName];
|
||||
|
||||
ASSERT_TRUE(add_set.size() == 2);
|
||||
ASSERT_TRUE(mul_set.size() == 2);
|
||||
ASSERT_TRUE(add_set.find(add1) != add_set.end());
|
||||
ASSERT_TRUE(add_set.find(add) != add_set.end());
|
||||
ASSERT_TRUE(mul_set.find(ab) != mul_set.end());
|
||||
ASSERT_TRUE(mul_set.find(ac) != mul_set.end());
|
||||
|
||||
auto new_node = pass.Run(fg_, add);
|
||||
ASSERT_NE(new_node, nullptr);
|
||||
(void)manager->Replace(add, new_node);
|
||||
pass.AfterProcess(add, new_node, fg_, func_graph_index);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(ab) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(ac) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(c) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("bc")) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("mul")) == 1);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
|
||||
|
||||
auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName];
|
||||
auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName];
|
||||
|
||||
ASSERT_TRUE(add_set_2.size() == 2);
|
||||
ASSERT_TRUE(mul_set_2.size() == 1);
|
||||
ASSERT_TRUE(add_set_2.find(add1) != add_set_2.end());
|
||||
ASSERT_TRUE(add_set_2.find(pass.m_->Get("bc")) != add_set_2.end());
|
||||
ASSERT_TRUE(mul_set_2.find(pass.m_->Get("mul")) != mul_set_2.end());
|
||||
|
||||
// build pattern
|
||||
check.src_pattern_.AddVar("a")
|
||||
.AddVar("b")
|
||||
.AddVar("c")
|
||||
.AddVar("d")
|
||||
.AddCNode("bc", {std::make_shared<Primitive>(kAddOpName), "b", "c"})
|
||||
.AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "a", "bc"})
|
||||
.AddCNode("add1", {std::make_shared<Primitive>(kAddOpName), "mul", "d"});
|
||||
|
||||
// pattern engine
|
||||
ASSERT_TRUE(check.build_pattern_map(add1));
|
||||
|
||||
// check
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("b"), b));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("c"), c));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("d"), d));
|
||||
|
||||
ASSERT_EQ(check.m_->Get("bc")->cast<CNodePtr>()->inputs().size(), 3);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(0),
|
||||
NewValueNode(std::make_shared<Primitive>(kAddOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(1), b));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("bc")->cast<CNodePtr>()->input(2), c));
|
||||
|
||||
ASSERT_EQ(check.m_->Get("mul")->cast<CNodePtr>()->inputs().size(), 3);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(0),
|
||||
NewValueNode(std::make_shared<Primitive>(kMulOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(1), a));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("mul")->cast<CNodePtr>()->input(2), check.m_->Get("bc")));
|
||||
|
||||
ASSERT_EQ(check.m_->Get("add1")->cast<CNodePtr>()->inputs().size(), 3);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(0),
|
||||
NewValueNode(std::make_shared<Primitive>(kAddOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(1), check.m_->Get("mul")));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(2), d));
|
||||
}
|
||||
|
||||
/// Feature: Fast PatternToPattern Pass
|
||||
/// Description: Fast PatternToPattern Pass rewrite graph
|
||||
/// Expectation: Get correct Graph
|
||||
TEST_F(TestFastPatternToPatternPass, Mul1) {
|
||||
// (a * (b1 + d) + (c1 * c2) * d) + e -> (a + d) + e
|
||||
// init
|
||||
auto check = CheckPattern();
|
||||
auto pass = TestFastMul1();
|
||||
|
||||
// build func graph
|
||||
auto a = std::make_shared<AnfNode>(fg_);
|
||||
auto b = std::make_shared<AnfNode>(fg_);
|
||||
auto c1 = std::make_shared<AnfNode>(fg_);
|
||||
auto c2 = std::make_shared<AnfNode>(fg_);
|
||||
auto d = std::make_shared<AnfNode>(fg_);
|
||||
auto e = std::make_shared<AnfNode>(fg_);
|
||||
|
||||
AnfNodePtr b_add_d =
|
||||
std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), b, d}, fg_);
|
||||
AnfNodePtr c1_mul_c2 = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), c1, c2}, fg_);
|
||||
AnfNodePtr a_mul = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a, b_add_d}, fg_);
|
||||
AnfNodePtr d_mul = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), c1_mul_c2, d}, fg_);
|
||||
AnfNodePtr add = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), a_mul, d_mul}, fg_);
|
||||
AnfNodePtr add1 = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), add, e}, fg_);
|
||||
|
||||
fg_->set_output(add1);
|
||||
auto manager = MakeManager({fg_});
|
||||
if (manager) {
|
||||
manager->AddFuncGraph(fg_);
|
||||
fg_->set_manager(manager);
|
||||
}
|
||||
auto func_graph_index = manager->func_graph_index(fg_);
|
||||
GenIndex(fg_, func_graph_index);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(b_add_d) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(c1_mul_c2) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(a_mul) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(d_mul) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(c1) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(c2) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 2);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(e) == 1);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
|
||||
|
||||
auto &add_set = func_graph_index->name_to_cnode_[kAddOpName];
|
||||
auto &mul_set = func_graph_index->name_to_cnode_[kMulOpName];
|
||||
|
||||
ASSERT_TRUE(add_set.size() == 3);
|
||||
ASSERT_TRUE(mul_set.size() == 3);
|
||||
ASSERT_TRUE(add_set.find(add1) != add_set.end());
|
||||
ASSERT_TRUE(add_set.find(add) != add_set.end());
|
||||
ASSERT_TRUE(add_set.find(b_add_d) != add_set.end());
|
||||
ASSERT_TRUE(mul_set.find(a_mul) != mul_set.end());
|
||||
ASSERT_TRUE(mul_set.find(d_mul) != mul_set.end());
|
||||
ASSERT_TRUE(mul_set.find(c1_mul_c2) != mul_set.end());
|
||||
|
||||
auto new_node = pass.Run(fg_, add);
|
||||
ASSERT_NE(new_node, nullptr);
|
||||
(void)manager->Replace(add, new_node);
|
||||
pass.AfterProcess(add, new_node, fg_, func_graph_index);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(b_add_d) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(c1_mul_c2) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(a_mul) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(d_mul) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(add) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(add1) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(pass.m_->Get("ad")) == 1);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(a) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(b) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(c1) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(c2) == 0);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(d) == 1);
|
||||
ASSERT_TRUE(func_graph_index->node_degree_.at(e) == 1);
|
||||
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.size() == 2);
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kAddOpName) != func_graph_index->name_to_cnode_.end());
|
||||
ASSERT_TRUE(func_graph_index->name_to_cnode_.find(kMulOpName) != func_graph_index->name_to_cnode_.end());
|
||||
|
||||
auto &add_set_2 = func_graph_index->name_to_cnode_[kAddOpName];
|
||||
auto &mul_set_2 = func_graph_index->name_to_cnode_[kMulOpName];
|
||||
|
||||
ASSERT_TRUE(add_set_2.size() == 1);
|
||||
ASSERT_TRUE(mul_set_2.size() == 1);
|
||||
ASSERT_TRUE(add_set_2.find(add1) != add_set_2.end());
|
||||
ASSERT_TRUE(mul_set_2.find(pass.m_->Get("ad")) != mul_set_2.end());
|
||||
|
||||
// build pattern
|
||||
check.src_pattern_.AddVar("a")
|
||||
.AddVar("d")
|
||||
.AddVar("e")
|
||||
.AddCNode("ad", {std::make_shared<Primitive>(kMulOpName), "a", "d"})
|
||||
.AddCNode("add1", {std::make_shared<Primitive>(kAddOpName), "ad", "e"});
|
||||
|
||||
// pattern engine
|
||||
ASSERT_TRUE(check.build_pattern_map(add1));
|
||||
|
||||
// check
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("a"), a));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("d"), d));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("e"), e));
|
||||
|
||||
ASSERT_EQ(check.m_->Get("ad")->cast<CNodePtr>()->inputs().size(), 3);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast<CNodePtr>()->input(0),
|
||||
NewValueNode(std::make_shared<Primitive>(kMulOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast<CNodePtr>()->input(1), a));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("ad")->cast<CNodePtr>()->input(2), d));
|
||||
|
||||
ASSERT_EQ(check.m_->Get("add1")->cast<CNodePtr>()->inputs().size(), 3);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(0),
|
||||
NewValueNode(std::make_shared<Primitive>(kAddOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(1), check.m_->Get("ad")));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get("add1")->cast<CNodePtr>()->input(2), e));
|
||||
}
|
||||
|
||||
namespace {
|
||||
void Check0(const FuncGraphIndexPtr &fg, const std::map<std::string, AnfNodePtr> &node_map) {
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAAddB)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kCAddD)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kMul)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAdd)) == kOne);
|
||||
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kA)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kB)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kC)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kD)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kE)) == kOne);
|
||||
|
||||
ASSERT_TRUE(fg->name_to_cnode_.size() == kTwo);
|
||||
ASSERT_TRUE(fg->name_to_cnode_.find(kAddOpName) != fg->name_to_cnode_.end());
|
||||
ASSERT_TRUE(fg->name_to_cnode_.find(kMulOpName) != fg->name_to_cnode_.end());
|
||||
|
||||
auto &add_set = fg->name_to_cnode_[kAddOpName];
|
||||
auto &mul_set = fg->name_to_cnode_[kMulOpName];
|
||||
|
||||
ASSERT_TRUE(add_set.size() == kThree);
|
||||
ASSERT_TRUE(mul_set.size() == kOne);
|
||||
ASSERT_TRUE(add_set.find(node_map.at(kAdd)) != add_set.end());
|
||||
ASSERT_TRUE(add_set.find(node_map.at(kAAddB)) != add_set.end());
|
||||
ASSERT_TRUE(add_set.find(node_map.at(kCAddD)) != add_set.end());
|
||||
ASSERT_TRUE(mul_set.find(node_map.at(kMul)) != mul_set.end());
|
||||
}
|
||||
void Check1(const TestFastMul2 &pass, const FuncGraphIndexPtr &fg, const std::map<std::string, AnfNodePtr> &node_map) {
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAAddB)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kCAddD)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kMul)) == kZero);
|
||||
ASSERT_TRUE(fg->node_degree_.at(pass.m_->Get(kMul)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kAdd)) == kOne);
|
||||
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kA)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kB)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kC)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kD)) == kOne);
|
||||
ASSERT_TRUE(fg->node_degree_.at(node_map.at(kE)) == kOne);
|
||||
|
||||
ASSERT_TRUE(fg->name_to_cnode_.size() == kTwo);
|
||||
ASSERT_TRUE(fg->name_to_cnode_.find(kAddOpName) != fg->name_to_cnode_.end());
|
||||
ASSERT_TRUE(fg->name_to_cnode_.find(kMulOpName) != fg->name_to_cnode_.end());
|
||||
|
||||
auto &add_set_2 = fg->name_to_cnode_[kAddOpName];
|
||||
auto &mul_set_2 = fg->name_to_cnode_[kMulOpName];
|
||||
|
||||
ASSERT_TRUE(add_set_2.size() == kThree);
|
||||
ASSERT_TRUE(mul_set_2.size() == kOne);
|
||||
ASSERT_TRUE(add_set_2.find(node_map.at(kAAddB)) != add_set_2.end());
|
||||
ASSERT_TRUE(add_set_2.find(node_map.at(kCAddD)) != add_set_2.end());
|
||||
ASSERT_TRUE(mul_set_2.find(pass.m_->Get(kMul)) != mul_set_2.end());
|
||||
}
|
||||
|
||||
void Check2(const CheckPattern &check, const std::map<std::string, AnfNodePtr> &node_map) {
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kA), node_map.at(kA)));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kB), node_map.at(kB)));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kC), node_map.at(kC)));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kD), node_map.at(kD)));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kE), node_map.at(kE)));
|
||||
|
||||
ASSERT_EQ(check.m_->Get(kAAddB)->cast<CNodePtr>()->inputs().size(), kThree);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast<CNodePtr>()->input(kZero),
|
||||
NewValueNode(std::make_shared<Primitive>(kAddOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast<CNodePtr>()->input(kOne), node_map.at(kA)));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAAddB)->cast<CNodePtr>()->input(kTwo), node_map.at(kB)));
|
||||
|
||||
ASSERT_EQ(check.m_->Get(kCAddD)->cast<CNodePtr>()->inputs().size(), kThree);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast<CNodePtr>()->input(kZero),
|
||||
NewValueNode(std::make_shared<Primitive>(kAddOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast<CNodePtr>()->input(kOne), node_map.at(kC)));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kCAddD)->cast<CNodePtr>()->input(kTwo), node_map.at(kD)));
|
||||
|
||||
ASSERT_EQ(check.m_->Get(kMul)->cast<CNodePtr>()->inputs().size(), kThree);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast<CNodePtr>()->input(kZero),
|
||||
NewValueNode(std::make_shared<Primitive>(kMulOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast<CNodePtr>()->input(kOne), node_map.at(kCAddD)));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kMul)->cast<CNodePtr>()->input(kTwo), node_map.at(kAAddB)));
|
||||
|
||||
ASSERT_EQ(check.m_->Get(kAdd)->cast<CNodePtr>()->inputs().size(), kThree);
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast<CNodePtr>()->input(kZero),
|
||||
NewValueNode(std::make_shared<Primitive>(kAddOpName))));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast<CNodePtr>()->input(kOne), check.m_->Get(kMul)));
|
||||
ASSERT_TRUE(opt::AnfEqual(check.m_->Get(kAdd)->cast<CNodePtr>()->input(kTwo), node_map.at(kE)));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/// Feature: Fast PatternToPattern Pass
|
||||
/// Description: Fast PatternToPattern Pass rewrite graph
|
||||
/// Expectation: Get correct Graph
|
||||
TEST_F(TestFastPatternToPatternPass, Mul2) {
|
||||
// ((a + b) * (c + d)) + e -> ((c + d) * (a + b)) + e
|
||||
// init
|
||||
auto check = CheckPattern();
|
||||
auto pass = TestFastMul2();
|
||||
|
||||
// build func graph
|
||||
auto a = std::make_shared<AnfNode>(fg_);
|
||||
auto b = std::make_shared<AnfNode>(fg_);
|
||||
auto c = std::make_shared<AnfNode>(fg_);
|
||||
auto d = std::make_shared<AnfNode>(fg_);
|
||||
auto e = std::make_shared<AnfNode>(fg_);
|
||||
|
||||
AnfNodePtr a_add_b =
|
||||
std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), a, b}, fg_);
|
||||
AnfNodePtr c_add_d =
|
||||
std::make_shared<CNode>(std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), c, d}, fg_);
|
||||
AnfNodePtr mul = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kMulOpName)), a_add_b, c_add_d}, fg_);
|
||||
AnfNodePtr add = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), mul, e}, fg_);
|
||||
|
||||
std::map<std::string, AnfNodePtr> node_map;
|
||||
node_map.emplace("a", a);
|
||||
node_map.emplace("b", b);
|
||||
node_map.emplace("c", c);
|
||||
node_map.emplace("d", d);
|
||||
node_map.emplace("e", e);
|
||||
node_map.emplace("a_add_b", a_add_b);
|
||||
node_map.emplace("c_add_d", c_add_d);
|
||||
node_map.emplace("mul", mul);
|
||||
node_map.emplace("add", add);
|
||||
|
||||
fg_->set_output(add);
|
||||
auto manager = MakeManager({fg_});
|
||||
if (manager) {
|
||||
manager->AddFuncGraph(fg_);
|
||||
fg_->set_manager(manager);
|
||||
}
|
||||
auto func_graph_index = manager->func_graph_index(fg_);
|
||||
GenIndex(fg_, func_graph_index);
|
||||
|
||||
Check0(func_graph_index, node_map);
|
||||
auto new_node = pass.Run(fg_, mul);
|
||||
ASSERT_NE(new_node, nullptr);
|
||||
(void)manager->Replace(mul, new_node);
|
||||
pass.AfterProcess(mul, new_node, fg_, func_graph_index);
|
||||
Check1(pass, func_graph_index, node_map);
|
||||
|
||||
// build pattern
|
||||
check.src_pattern_.AddVar("a")
|
||||
.AddVar("b")
|
||||
.AddVar("c")
|
||||
.AddVar("d")
|
||||
.AddVar("e")
|
||||
.AddCNode("a_add_b", {std::make_shared<Primitive>(kAddOpName), "a", "b"})
|
||||
.AddCNode("c_add_d", {std::make_shared<Primitive>(kAddOpName), "c", "d"})
|
||||
.AddCNode("mul", {std::make_shared<Primitive>(kMulOpName), "c_add_d", "a_add_b"})
|
||||
.AddCNode("add", {std::make_shared<Primitive>(kAddOpName), "mul", "e"});
|
||||
|
||||
// pattern engine
|
||||
ASSERT_TRUE(check.build_pattern_map(add));
|
||||
Check2(check, node_map);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -14,62 +14,11 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "common/common_test.h"
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/value.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "pattern_to_pattern_pass_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TestPatternToPatternPass : public UT::Common {
|
||||
public:
|
||||
TestPatternToPatternPass() : fg_(std::make_shared<FuncGraph>()){};
|
||||
|
||||
public:
|
||||
FuncGraphPtr fg_;
|
||||
};
|
||||
|
||||
class CheckPattern {
|
||||
public:
|
||||
CheckPattern()
|
||||
: m_(std::make_shared<PatternMap>()),
|
||||
src_pattern_(SrcPattern(m_)),
|
||||
pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
|
||||
primitive_vars_(std::make_shared<PrimitiveVarMap>()),
|
||||
equiv_(std::make_shared<Equiv>()){};
|
||||
bool build_pattern_map(const AnfNodePtr &node) {
|
||||
VarPtr root_g = std::make_shared<Var>("RootG");
|
||||
auto src_pattern_root = SexpToNode(src_pattern_.GetRoot(), root_g, primitive_vars_.get(), multigraph_);
|
||||
auto primitive = GetCNodePrimitive(src_pattern_root);
|
||||
if (IsPrimitiveCNode(node, primitive)) {
|
||||
MS_EXCEPTION_IF_NULL(primitive_vars_);
|
||||
MS_EXCEPTION_IF_NULL(equiv_);
|
||||
equiv_->clear();
|
||||
EquivPtr equiv = pattern_engine_.Match(src_pattern_root, node, *primitive_vars_, equiv_);
|
||||
if (equiv != nullptr && !equiv->empty()) {
|
||||
return src_pattern_.build_pattern_map(node, equiv);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
PatternMapPtr m_;
|
||||
SrcPattern src_pattern_;
|
||||
PatternEngine pattern_engine_;
|
||||
PrimitiveVarMapPtr primitive_vars_;
|
||||
EquivPtr equiv_;
|
||||
bool multigraph_ = true;
|
||||
};
|
||||
|
||||
namespace {
|
||||
class TestMul0 : public PatternToPatternPass {
|
||||
// a*b + a*c -> a*(b+c)
|
||||
public:
|
||||
|
@ -227,6 +176,15 @@ class TestError1 : public PatternToPatternPass {
|
|||
}
|
||||
bool CheckMatchedDAG(const PatternMap &, const FuncGraphPtr &, const AnfNodePtr &) const override { return true; }
|
||||
};
|
||||
} // namespace
|
||||
|
||||
class TestPatternToPatternPass : public UT::Common {
|
||||
public:
|
||||
TestPatternToPatternPass() : fg_(std::make_shared<FuncGraph>()){};
|
||||
|
||||
public:
|
||||
FuncGraphPtr fg_;
|
||||
};
|
||||
|
||||
/// Feature: PatternToPattern Pass
|
||||
/// Description: PatternToPattern Pass rewrite graph
|
||||
|
@ -422,8 +380,7 @@ TEST_F(TestPatternToPatternPass, Null) {
|
|||
AnfNodePtr add = std::make_shared<CNode>(
|
||||
std::vector<AnfNodePtr>{NewValueNode(std::make_shared<Primitive>(kAddOpName)), ab, ac}, fg_);
|
||||
|
||||
auto new_node = pass.Run(fg_, add);
|
||||
ASSERT_EQ(new_node, nullptr);
|
||||
EXPECT_THROW(pass.Run(fg_, add), std::runtime_error);
|
||||
}
|
||||
|
||||
/// Feature: PatternToPattern Pass
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_TESTS_UT_CPP_PRE_ACTIVATE_COMMON_PATTERN_TO_PATTERN_PASS_UTILS_H_
|
||||
#define MINDSPORE_TESTS_UT_CPP_PRE_ACTIVATE_COMMON_PATTERN_TO_PATTERN_PASS_UTILS_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "common/common_test.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/value.h"
|
||||
#include "include/common/utils/utils.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
|
||||
#define private public
|
||||
#define protected public
|
||||
#include "backend/common/optimizer/pattern_to_pattern.h"
|
||||
#undef private
|
||||
#undef protected
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class CheckPattern {
|
||||
public:
|
||||
CheckPattern()
|
||||
: m_(std::make_shared<PatternMap>()),
|
||||
src_pattern_(SrcPattern(m_)),
|
||||
pattern_engine_(PatternEngine(std::make_shared<Visitor>())),
|
||||
primitive_vars_(std::make_shared<PrimitiveVarMap>()),
|
||||
equiv_(std::make_shared<Equiv>()){};
|
||||
bool build_pattern_map(const AnfNodePtr &node) {
|
||||
VarPtr root_g = std::make_shared<Var>("RootG");
|
||||
auto src_pattern_root = SexpToNode(src_pattern_.GetRoot(), root_g, primitive_vars_.get(), multigraph_);
|
||||
auto primitive = GetCNodePrimitive(src_pattern_root);
|
||||
if (IsPrimitiveCNode(node, primitive)) {
|
||||
MS_EXCEPTION_IF_NULL(primitive_vars_);
|
||||
MS_EXCEPTION_IF_NULL(equiv_);
|
||||
equiv_->clear();
|
||||
EquivPtr equiv = pattern_engine_.Match(src_pattern_root, node, *primitive_vars_, equiv_);
|
||||
if (equiv != nullptr && !equiv->empty()) {
|
||||
return src_pattern_.build_pattern_map(node, equiv);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
PatternMapPtr m_;
|
||||
SrcPattern src_pattern_;
|
||||
PatternEngine pattern_engine_;
|
||||
PrimitiveVarMapPtr primitive_vars_;
|
||||
EquivPtr equiv_;
|
||||
bool multigraph_ = true;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_TESTS_UT_CPP_PRE_ACTIVATE_COMMON_PATTERN_TO_PATTERN_PASS_UTILS_H_
|
Loading…
Reference in New Issue