forked from mindspore-Ecosystem/mindspore
Refactor GraphKernelCluster
Use a disjoint-set to maintain the clusters before building the ANF graph, which speeds up graph building. Dump the node changes to a file if the graph-kernel flag "dump_as_text" is set. The CheckCircle algorithm is unchanged. Reuse the FuseNodesToSubGraph interface.
This commit is contained in:
parent
3ca52269db
commit
3bf0d8a19a
|
@ -1,197 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "base/core_ops.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "vm/segment_runner.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
// Include-predicate used when collecting fusible nodes starting from cur_node.
// A node is followed when it is the start node itself, when it is a fusible op,
// or when it is a TupleGetItem whose real input is a graph kernel.
// Every other node is excluded from the search.
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
  if (cur_node == node || IsFusibleOp(node)) {
    return FOLLOW;
  }
  if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
    return EXCLUDE;
  }
  // Only a TupleGetItem reaches this point; decide by looking at its real input.
  auto real_input = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
  return AnfAlgo::IsGraphKernel(real_input) ? FOLLOW : EXCLUDE;
}
|
||||
|
||||
// The GetItem node should be fused with its real input and users.
|
||||
// If its real input is not in the fuse_list, the GetItem should be excluded.
|
||||
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
|
||||
if (fused_op.empty()) return AnfNodePtrList();
|
||||
std::set<AnfNodePtr> fused_op_set(fused_op.begin(), fused_op.end());
|
||||
auto check_include = [&fused_op_set](const AnfNodePtr &node) { return fused_op_set.count(node) ? FOLLOW : EXCLUDE; };
|
||||
|
||||
auto mng = fused_op[0]->func_graph()->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
AnfNodePtrList remove_list;
|
||||
for (auto getitem : fused_op_set) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(getitem, prim::kPrimTupleGetItem)) continue;
|
||||
|
||||
// GetItem should be fused with its real input.
|
||||
auto prev_node = getitem->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
|
||||
if (check_include(prev_node) == EXCLUDE) {
|
||||
remove_list.push_back(getitem);
|
||||
break;
|
||||
}
|
||||
|
||||
// GetItem should be fused with its all users.
|
||||
const auto &users = mng->node_users()[getitem];
|
||||
if (std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
|
||||
return check_include(user.first) == EXCLUDE;
|
||||
})) {
|
||||
remove_list = DeepLinkedGraphSearch(getitem, check_include);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!remove_list.empty()) {
|
||||
for (auto node : remove_list) {
|
||||
fused_op_set.erase(node);
|
||||
}
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
// keep the original order of fused_op.
|
||||
AnfNodePtrList result;
|
||||
for (auto node : fused_op) {
|
||||
if (fused_op_set.count(node)) {
|
||||
result.push_back(node);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
|
||||
// Search fusable nodes according input direction.
|
||||
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
|
||||
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
|
||||
if (used_nodes.size() > 1) {
|
||||
used_nodes = RemoveCircle(used_nodes, dep_pri);
|
||||
}
|
||||
used_nodes = RemoveWildGetitem(used_nodes);
|
||||
TopoSortForNodeList(&used_nodes);
|
||||
return used_nodes;
|
||||
}
|
||||
|
||||
bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr> &todos,
|
||||
std::unordered_set<AnfNodePtr> *fused_ops) {
|
||||
bool changed = false;
|
||||
auto mng = kernel_graph->manager();
|
||||
|
||||
// depend_prior[depend] = pair(prior, behind)
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
|
||||
// InitDependPrior(todos, &depend_prior);
|
||||
|
||||
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
|
||||
auto node = (*iter)->cast<CNodePtr>();
|
||||
if (node == nullptr || IsKeepBasicNode(node) || fused_ops->count(node)) {
|
||||
continue;
|
||||
}
|
||||
bool is_fusible_op = IsFusibleOp(node);
|
||||
if (!is_fusible_op || !kernel_graph->nodes().contains(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto fuse_nodes = FindFuseCNodes(node, depend_prior);
|
||||
if (fuse_nodes.empty()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (fuse_nodes.size() == 1) {
|
||||
// Do not fuse a single GraphKernel again.
|
||||
// Do not fuse a single Assign.
|
||||
if (AnfAlgo::IsGraphKernel(fuse_nodes[0]) || IsPrimitiveCNode(fuse_nodes[0], prim::kPrimAssign)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
changed = true;
|
||||
fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
|
||||
AnfNodePtr fused_new_node;
|
||||
AnfNodePtrList old_outputs;
|
||||
std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "fusion");
|
||||
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs);
|
||||
}
|
||||
std::dynamic_pointer_cast<session::KernelGraph>(kernel_graph)->SetExecOrderByDefault();
|
||||
return changed;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool FuseBasicOps(const FuncGraphPtr &func_graph) {
|
||||
std::unordered_set<AnfNodePtr> fused_ops;
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
std::reverse(todos.begin(), todos.end());
|
||||
return FuseBasicOps(func_graph, todos, &fused_ops);
|
||||
}
|
||||
|
||||
void EliminateGetitem(const FuncGraphPtr &func_graph) {
|
||||
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
|
||||
auto todos = TopoSort(func_graph->get_return());
|
||||
for (auto node : todos) {
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
bool changed = FuseBasicOps(func_graph);
|
||||
if (changed) {
|
||||
EliminateGetitem(func_graph);
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -1,36 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool FuseBasicOps(const FuncGraphPtr &kernel_graph);
|
||||
|
||||
// Optimizer pass registered under the name "basic_ops_fusion".
class BasicOpsFusion : public Pass {
|
||||
public:
|
||||
BasicOpsFusion() : Pass("basic_ops_fusion") {}
|
||||
~BasicOpsFusion() override = default;
|
||||
// Runs the fusion on func_graph; returns true when the graph was changed.
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
using FuseBasicPtr = std::shared_ptr<BasicOpsFusion>;
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
|
|
@ -18,6 +18,7 @@
|
|||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -99,7 +100,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) {
|
|||
continue;
|
||||
}
|
||||
// Cast cannot fuse with its input
|
||||
if (IsFusibleOp((cast_node->cast<CNodePtr>())->input(1))) {
|
||||
if (IsClusterableOp((cast_node->cast<CNodePtr>())->input(1))) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,217 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/ordered_set.h"
|
||||
#include "utils/ordered_map.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "vm/segment_runner.h"
|
||||
#include "debug/draw.h"
|
||||
#include "debug/anf_ir_dump.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include) {
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
for (auto &root : roots) {
|
||||
auto tmp = DeepLinkedGraphSearch(root, include);
|
||||
inputs.insert(inputs.end(), tmp.begin(), tmp.end());
|
||||
}
|
||||
return inputs;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
|
||||
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
|
||||
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
|
||||
return false;
|
||||
}
|
||||
circle_nodes->clear();
|
||||
|
||||
auto InputEdges = [&depend_prior](const CNodePtr &cnode) {
|
||||
std::set<AnfNodePtr> edges;
|
||||
auto range = depend_prior.equal_range(cnode);
|
||||
for (auto iter = range.first; iter != range.second; ++iter) {
|
||||
edges.insert(iter->second.first);
|
||||
}
|
||||
auto inputs = cnode->inputs();
|
||||
for (auto input : inputs) {
|
||||
edges.insert(input);
|
||||
}
|
||||
return edges;
|
||||
};
|
||||
|
||||
// consider prior depend both in fused_op_set
|
||||
auto range = depend_prior.equal_range(check_node);
|
||||
for (auto iter = range.first; iter != range.second; ++iter) {
|
||||
if (fused_op_set.count(iter->second.first)) {
|
||||
circle_nodes->push_back(iter->second.first);
|
||||
}
|
||||
}
|
||||
|
||||
std::set<AnfNodePtr> cached_done_set;
|
||||
auto cnode = check_node->cast<CNodePtr>();
|
||||
const auto &inputs = InputEdges(cnode);
|
||||
// there is a input not in fused_op_set, but the input depends on the fused_op_set
|
||||
for (auto input : inputs) {
|
||||
if (input->isa<CNode>() && !fused_op_set.count(input)) {
|
||||
bool has_circle = false;
|
||||
std::set<AnfNodePtr> done;
|
||||
std::vector<AnfNodePtr> todos = {input};
|
||||
while (!todos.empty()) {
|
||||
auto node = todos.back();
|
||||
todos.pop_back();
|
||||
if (done.count(node) || cached_unconnected_set->count(node) || cached_done_set.count(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
done.insert(node);
|
||||
if (fused_op_set.count(node)) {
|
||||
has_circle = true;
|
||||
circle_nodes->push_back(node);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (node->isa<CNode>()) {
|
||||
auto cnode_ptr = node->cast<CNodePtr>();
|
||||
for (auto it : InputEdges(cnode_ptr)) {
|
||||
if (it->isa<CNode>()) {
|
||||
todos.push_back(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (has_circle) {
|
||||
cached_done_set.insert(done.begin(), done.end());
|
||||
} else {
|
||||
cached_unconnected_set->insert(done.begin(), done.end());
|
||||
}
|
||||
done.clear();
|
||||
}
|
||||
}
|
||||
|
||||
return !circle_nodes->empty();
|
||||
}
|
||||
|
||||
// Removes from fused_op every node whose fusion would form a circle, together
// with the nodes of the fuse set reachable from it through
// DeepLinkedGraphSearch. Nodes are examined from the back of fused_op to the
// front; the survivors are returned in their original order.
AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
                            const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
  std::set<AnfNodePtr> unconnected_cache;
  std::set<AnfNodePtr> kept(fused_op.begin(), fused_op.end());
  auto only_kept = [&kept](const AnfNodePtr &node) { return kept.count(node) ? FOLLOW : EXCLUDE; };

  std::vector<AnfNodePtr> offenders;
  for (auto rit = fused_op.rbegin(); rit != fused_op.rend(); ++rit) {
    offenders.clear();
    if (!CheckCircle(kept, *rit, &unconnected_cache, &offenders, depend_prior)) {
      continue;
    }
    // Drop the circle nodes and everything in the fuse set that hangs on them.
    const std::vector<AnfNodePtr> doomed_nodes = DeepLinkedGraphSearch(offenders, only_kept);
    for (const auto &doomed : doomed_nodes) {
      kept.erase(doomed);
    }
  }

  std::vector<AnfNodePtr> survivors;
  for (const auto &node : fused_op) {
    if (kept.count(node)) {
      survivors.push_back(node);
    }
  }
  return survivors;
}
|
||||
|
||||
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
|
||||
if (lst->size() < 2) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> res;
|
||||
std::set<AnfNodePtr> node_sets(lst->begin(), lst->end());
|
||||
OrderedMap<AnfNodePtr, std::set<AnfNodePtr>> ins;
|
||||
OrderedMap<AnfNodePtr, OrderedSet<AnfNodePtr>> outs;
|
||||
std::queue<AnfNodePtr> q;
|
||||
for (auto node : *lst) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (auto input : cnode->inputs()) {
|
||||
if (!node_sets.count(input)) {
|
||||
continue;
|
||||
}
|
||||
// out_degree
|
||||
outs[input].insert(node);
|
||||
// in_degree
|
||||
ins[node].insert(input);
|
||||
}
|
||||
if (!ins.count(node)) {
|
||||
ins[node] = {};
|
||||
}
|
||||
}
|
||||
|
||||
for (auto p : ins) {
|
||||
if (p.second.size() == 0) {
|
||||
q.push(p.first);
|
||||
}
|
||||
}
|
||||
|
||||
while (!q.empty()) {
|
||||
auto node = q.front();
|
||||
q.pop();
|
||||
res.push_back(node);
|
||||
if (!outs.count(node)) {
|
||||
continue;
|
||||
}
|
||||
for (auto out : outs[node]) {
|
||||
if (!ins.count(out)) {
|
||||
continue;
|
||||
}
|
||||
ins[out].erase(node);
|
||||
if (ins[out].size() == 0) {
|
||||
q.push(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
lst->assign(res.begin(), res.end());
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -1,37 +0,0 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
||||
|
||||
#include <limits>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
|
||||
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior);
|
||||
|
||||
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
|
|
@ -0,0 +1,477 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <fstream>
|
||||
|
||||
#include "base/core_ops.h"
|
||||
#include "ir/graph_utils.h"
|
||||
#include "debug/common.h"
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/optimizer/pass/getitem_tuple.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
std::vector<PrimitivePtr> GetClusterableOpList() {
|
||||
std::vector<PrimitivePtr> clusterable_ops = {
|
||||
prim::kPrimAbs,
|
||||
prim::kPrimRound,
|
||||
prim::kPrimNeg,
|
||||
prim::kPrimExp,
|
||||
prim::kPrimAdd,
|
||||
prim::kPrimCast,
|
||||
prim::kPrimMul,
|
||||
prim::kPrimMinimum,
|
||||
prim::kPrimMaximum,
|
||||
prim::kPrimLog,
|
||||
prim::kPrimPow,
|
||||
prim::kPrimSub,
|
||||
prim::kPrimRsqrt,
|
||||
prim::kPrimSqrt,
|
||||
prim::kPrimAddN,
|
||||
prim::kPrimReciprocal,
|
||||
prim::kPrimTanh,
|
||||
prim::kPrimReshape,
|
||||
prim::kPrimTranspose,
|
||||
prim::kPrimRealDiv,
|
||||
prim::kPrimReduceSum,
|
||||
prim::kPrimEqual,
|
||||
prim::kPrimAssign,
|
||||
prim::kPrimInplaceAssign,
|
||||
#if ENABLE_D
|
||||
prim::kPrimMatMul,
|
||||
prim::KPrimTransData,
|
||||
#elif ENABLE_GPU
|
||||
prim::kPrimReduceMax,
|
||||
prim::kPrimReduceMin,
|
||||
prim::kPrimGreater,
|
||||
prim::kPrimLess,
|
||||
prim::kPrimGreaterEqual,
|
||||
prim::kPrimLessEqual,
|
||||
prim::kPrimSelect,
|
||||
#endif
|
||||
};
|
||||
const auto &flags = context::GraphKernelFlags::GetInstance();
|
||||
OpListFilter(&clusterable_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops);
|
||||
return clusterable_ops;
|
||||
}
|
||||
|
||||
size_t CountGraphKernelInnerNodes(const AnfNodePtr &node) {
|
||||
AnfNodePtrList node_list;
|
||||
kernel::GetValidKernelNodes(AnfAlgo::GetCNodeFuncGraphPtr(node), &node_list);
|
||||
return node_list.size();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool IsClusterableOp(const AnfNodePtr &node) {
|
||||
if (IsKeepBasicNode(node)) {
|
||||
return false;
|
||||
}
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
return true;
|
||||
}
|
||||
auto op_list = GetClusterableOpList();
|
||||
bool node_in_oplist = std::any_of(op_list.begin(), op_list.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
if (!node_in_oplist) {
|
||||
return false;
|
||||
}
|
||||
#if ENABLE_D
|
||||
// For AICPU operators, only the Reshape can be clustered.
|
||||
if (AnfAlgo::GetProcessor(node) != kernel::Processor::AICORE && !IsPrimitiveCNode(node, prim::kPrimReshape)) {
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
class Graph {
|
||||
struct Cluster {
|
||||
size_t cluster_id_; // node_id of the representative.
|
||||
size_t cluster_size_{1}; // size of cluster, composite node is considered as one node.
|
||||
size_t basic_op_cnt_{1}; // basic node count, the inner nodes of composite node are counted.
|
||||
std::set<size_t> inputs_; // inputs' cluster_id.
|
||||
size_t seed_{0}; // visited flag of dfs.
|
||||
|
||||
Cluster(size_t node_id, const AnfNodePtr &node, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map)
|
||||
: cluster_id_(node_id) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
basic_op_cnt_ = 0;
|
||||
} else if (AnfAlgo::IsGraphKernel(node)) {
|
||||
// the basic_op_cnt_ is used to limit the composite op size
|
||||
basic_op_cnt_ = CountGraphKernelInnerNodes(node);
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
for (const auto &inp : cnode->inputs()) {
|
||||
auto iter = node_idx_map.find(inp);
|
||||
if (iter != node_idx_map.end()) {
|
||||
// At the beginning, cluster_id is equal to node_id
|
||||
inputs_.insert(iter->second);
|
||||
}
|
||||
}
|
||||
}
|
||||
~Cluster() = default;
|
||||
|
||||
void Merge(Cluster *other_cluster) {
|
||||
other_cluster->cluster_id_ = cluster_id_;
|
||||
cluster_size_ += other_cluster->cluster_size_;
|
||||
basic_op_cnt_ += other_cluster->basic_op_cnt_;
|
||||
std::for_each(other_cluster->inputs_.begin(), other_cluster->inputs_.end(),
|
||||
[this](size_t inp) { this->inputs_.insert(inp); });
|
||||
other_cluster->Clean();
|
||||
}
|
||||
|
||||
// clean the info to free memory.
|
||||
void Clean() {
|
||||
inputs_.clear();
|
||||
cluster_size_ = 0;
|
||||
basic_op_cnt_ = 0;
|
||||
}
|
||||
}; // struct Cluster
|
||||
|
||||
public:
|
||||
// Init and build graph
|
||||
Graph(const AnfNodePtrList &nodes, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map) {
|
||||
clusters_.reserve(nodes.size());
|
||||
for (size_t i = 0; i < nodes.size(); i++) {
|
||||
clusters_.emplace_back(i, nodes[i], node_idx_map);
|
||||
}
|
||||
}
|
||||
~Graph() = default;
|
||||
|
||||
// find the representative of the cluster
|
||||
int Find(size_t node_id) {
|
||||
size_t &pre_id = clusters_[node_id].cluster_id_;
|
||||
return (pre_id == clusters_[pre_id].cluster_id_) ? pre_id : (pre_id = Find(pre_id));
|
||||
}
|
||||
|
||||
// merge clusters, the smallest cluster id will be the new cluster id.
|
||||
void Merge(const std::set<size_t> &candidates) {
|
||||
for (auto iter = ++candidates.begin(); iter != candidates.end(); ++iter) {
|
||||
clusters_[*candidates.begin()].Merge(&clusters_[*iter]);
|
||||
}
|
||||
}
|
||||
|
||||
// Collect nodes together that are in the same cluster.
|
||||
std::vector<std::vector<size_t>> CollectClusters() {
|
||||
std::vector<std::vector<size_t>> cluster_map(clusters_.size());
|
||||
for (size_t i = 0; i < clusters_.size(); i++) {
|
||||
cluster_map[Find(i)].push_back(i);
|
||||
}
|
||||
return cluster_map;
|
||||
}
|
||||
|
||||
using VisitFunc = std::function<IncludeType(size_t)>;
|
||||
void Dfs(size_t node_id, VisitFunc visitor) {
|
||||
++seen_;
|
||||
return DepthFirstSearch(Find(node_id), visitor);
|
||||
}
|
||||
|
||||
// Get cluster size
|
||||
size_t GetSize(size_t cluster_id) { return clusters_[Find(cluster_id)].cluster_size_; }
|
||||
|
||||
// Get cluster's basic op count
|
||||
size_t GetBasicNodeCount(size_t cluster_id) { return clusters_[Find(cluster_id)].basic_op_cnt_; }
|
||||
|
||||
// Get cluster's inputs
|
||||
const std::set<size_t> &GetInputs(size_t cluster_id) {
|
||||
cluster_id = Find(cluster_id);
|
||||
RefreshInputs(cluster_id);
|
||||
return clusters_[cluster_id].inputs_;
|
||||
}
|
||||
|
||||
private:
|
||||
void RefreshInputs(size_t i) {
|
||||
auto &inputs = clusters_[i].inputs_;
|
||||
for (auto iter = inputs.begin(); iter != inputs.end();) {
|
||||
size_t new_id = Find(*iter);
|
||||
if (new_id != *iter) {
|
||||
iter = inputs.erase(iter);
|
||||
inputs.insert(new_id);
|
||||
} else {
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
inputs.erase(i);
|
||||
}
|
||||
|
||||
void DepthFirstSearch(size_t cluster_id, const VisitFunc &visitor) {
|
||||
if (clusters_[cluster_id].seed_ >= seen_) return;
|
||||
clusters_[cluster_id].seed_ = seen_;
|
||||
if (visitor(cluster_id) != FOLLOW) {
|
||||
return;
|
||||
}
|
||||
// traverse inputs in descending order.
|
||||
const auto &inputs = GetInputs(cluster_id);
|
||||
for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
|
||||
DepthFirstSearch(*iter, visitor);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Cluster> clusters_;
|
||||
size_t seen_{0};
|
||||
}; // class Graph
|
||||
|
||||
class CircleChecker {
|
||||
public:
|
||||
explicit CircleChecker(GraphPtr graph) : graph_(graph) {}
|
||||
~CircleChecker() = default;
|
||||
|
||||
void RemoveCircle(std::set<size_t> *candidates) {
|
||||
if (candidates->size() <= 1) {
|
||||
return;
|
||||
}
|
||||
candidates_ = candidates;
|
||||
std::vector<size_t> tmp_list(candidates->begin(), candidates->end());
|
||||
for (auto c : tmp_list) {
|
||||
if (!candidates->count(c)) continue;
|
||||
circle_nodes_.clear();
|
||||
if (CheckCircle(c)) {
|
||||
RemoveCircleNodesFromCandidates();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/**
|
||||
* Check circle. the candidate is collected into circle_nodes_ if it will form a circle.
|
||||
*
|
||||
* algorithm:
|
||||
* Search from the basenode's input that is NOT in candidates (the basenode is a candidate),
|
||||
* If it depends on a node that belongs to candidates, it will form a circle.
|
||||
* e.g. A -> x -> ... -> B
|
||||
* -> y -> ... -> C
|
||||
* In this case, A, B and C are candidates while x and y are not.
|
||||
* Both x and y are inputs of A. assumes A is the basenode.
|
||||
* When searching from x, the B will be found and added into circle_nodes list,
|
||||
* and then when searching from y, the C will be found and added into circle_nodes list.
|
||||
*/
|
||||
bool CheckCircle(size_t basenode) {
|
||||
const auto &inputs = graph_->GetInputs(basenode);
|
||||
std::set<size_t> visited_circle_nodes;
|
||||
for (auto x : inputs) {
|
||||
if (candidates_->count(x)) continue;
|
||||
bool has_circle = false;
|
||||
std::set<size_t> done;
|
||||
auto vis_func = [this, &has_circle, &done, &visited_circle_nodes](size_t node_id) {
|
||||
if (done.count(node_id) || acyclic_nodes_.count(node_id) || visited_circle_nodes.count(node_id)) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
done.insert(node_id);
|
||||
if (candidates_->count(node_id)) {
|
||||
has_circle = true;
|
||||
circle_nodes_.push_back(node_id);
|
||||
return EXCLUDE;
|
||||
}
|
||||
// all nodes are indexed by topo order,
|
||||
// so if the node_id is less than the minimal candidate, a cycle cannot be formed from this node.
|
||||
if (candidates_->empty() || node_id < *candidates_->begin()) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
return FOLLOW;
|
||||
};
|
||||
graph_->Dfs(x, vis_func);
|
||||
if (has_circle) {
|
||||
visited_circle_nodes.insert(done.begin(), done.end());
|
||||
} else {
|
||||
acyclic_nodes_.insert(done.begin(), done.end());
|
||||
}
|
||||
}
|
||||
return !circle_nodes_.empty();
|
||||
}
|
||||
|
||||
// remove all circle nodes from candidates
|
||||
void RemoveCircleNodesFromCandidates() {
|
||||
auto remove_from_candidates = [this](size_t node_id) {
|
||||
if (candidates_->count(node_id)) {
|
||||
candidates_->erase(node_id);
|
||||
return FOLLOW;
|
||||
}
|
||||
return EXCLUDE;
|
||||
};
|
||||
for (auto node : circle_nodes_) {
|
||||
graph_->Dfs(node, remove_from_candidates);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
GraphPtr graph_; // bind the global graph
|
||||
std::set<size_t> *candidates_{nullptr}; // bind the input candidates
|
||||
std::vector<size_t> circle_nodes_;
|
||||
std::set<size_t> acyclic_nodes_;
|
||||
}; // CircleChecker
|
||||
|
||||
std::set<size_t> GraphKernelCluster::FindCandidates(size_t basenode_id) {
|
||||
std::set<size_t> candidates;
|
||||
auto include = [this, &candidates, func_graph = nodes_[basenode_id]->func_graph()](size_t cluster_id) {
|
||||
const AnfNodePtr &node = this->nodes_[cluster_id];
|
||||
if (node->func_graph() != func_graph) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
if (!IsClusterableOp(node) && !IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
|
||||
return EXCLUDE;
|
||||
}
|
||||
candidates.insert(cluster_id);
|
||||
// Do not search from clustered node again.
|
||||
if (this->graph_->GetSize(cluster_id) > 1) {
|
||||
return NOFOLLOW;
|
||||
}
|
||||
return FOLLOW;
|
||||
};
|
||||
graph_->Dfs(basenode_id, include);
|
||||
return candidates;
|
||||
}
|
||||
|
||||
bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
|
||||
bool changed = false;
|
||||
for (int i = nodes_.size() - 1; i >= 0; i--) {
|
||||
// if the node has been clustered, it has tried to find its previous nodes, so it's unnecessary to try again.
|
||||
if (graph_->GetSize(i) > 1) {
|
||||
continue;
|
||||
}
|
||||
auto candidates = FindCandidates(i);
|
||||
CircleChecker(graph_).RemoveCircle(&candidates);
|
||||
RemoveWildGetitem(&candidates);
|
||||
if (candidates.empty()) continue;
|
||||
// merge candidates into one cluster
|
||||
graph_->Merge(candidates);
|
||||
}
|
||||
|
||||
// Rebuild func_graphs
|
||||
auto clusters = graph_->CollectClusters();
|
||||
for (size_t i = 0; i < clusters.size(); i++) {
|
||||
auto node_without_getitem = std::count_if(clusters[i].begin(), clusters[i].end(), [this](size_t node_id) {
|
||||
return !IsPrimitiveCNode(this->nodes_[node_id], prim::kPrimTupleGetItem);
|
||||
});
|
||||
if (node_without_getitem == 0) continue;
|
||||
if (node_without_getitem == 1) {
|
||||
// Do not cluster a single GraphKernel again.
|
||||
// Do not cluster a single Assign.
|
||||
const auto &node = nodes_[clusters[i][0]];
|
||||
if (AnfAlgo::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
CreateFuncGraph(func_graph, clusters[i]);
|
||||
changed = true;
|
||||
}
|
||||
return changed;
|
||||
}
|
||||
|
||||
void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) {
|
||||
AnfNodePtrList old_nodes;
|
||||
AnfNodePtr new_node;
|
||||
std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
|
||||
[this](size_t id) { return this->nodes_[id]; });
|
||||
std::tie(new_node, std::ignore) = FuseNodesToSubGraph(old_nodes, func_graph, "fusion");
|
||||
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
|
||||
eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node));
|
||||
if (context::GraphKernelFlags::GetInstance().dump_as_text) {
|
||||
DumpClusterInfo(old_nodes, new_node);
|
||||
}
|
||||
}
|
||||
|
||||
// Append a text record to dump_buf_ describing which `old_nodes` were fused into `new_node`.
// The function body is compiled only when ENABLE_DUMP_IR is defined; otherwise it does nothing.
void GraphKernelCluster::DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node) {
#ifdef ENABLE_DUMP_IR
  // Writes "<full name> = <debug string>" followed by a line break.
  auto write_node = [this](const AnfNodePtr &n) {
    dump_buf_ << n->fullname_with_scope() << " = " << n->DebugString() << std::endl;
  };

  dump_buf_ << "Source nodes of ";
  write_node(new_node);
  for (const auto &old_node : old_nodes) {
    dump_buf_ << " ";
    write_node(old_node);
  }
  dump_buf_ << "=======================" << std::endl;
#endif
}
|
||||
|
||||
// Append the text collected in dump_buf_ to "./<kGraphKernelDumpPath>/graph_kernel_cluster.txt".
// The function body is compiled only when ENABLE_DUMP_IR is defined; otherwise it does nothing.
// On failure to resolve the path or to open the file an error is logged and nothing is written.
void GraphKernelCluster::DumpToFile() {
#ifdef ENABLE_DUMP_IR
  auto pathname = std::string("./") + kGraphKernelDumpPath + "/graph_kernel_cluster.txt";
  auto realpath = Common::GetRealPath(pathname);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Get real path failed. path=" << pathname;
    return;
  }
  // The file is opened in append mode, so earlier content of the file is kept.
  std::ofstream fout(realpath.value(), std::ios::app);
  if (!fout.is_open()) {
    MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!";
    return;
  }
  fout << dump_buf_.str() << std::endl;
  fout.close();
  // Reset the buffer once its content has been written. Because the file is appended to,
  // keeping the old text would write it again if this pass object dumps a second time.
  dump_buf_.str("");
  dump_buf_.clear();
#endif
}
|
||||
|
||||
// The GetItem node should be clustered with its real input.
|
||||
// If its real input is not in the candidates, the GetItem should be excluded.
|
||||
void GraphKernelCluster::RemoveWildGetitem(std::set<size_t> *candidates) {
|
||||
for (auto iter = candidates->begin(); iter != candidates->end();) {
|
||||
size_t cluster_id = *iter;
|
||||
/*The implied condition is graph->GetSize(cluster_id) == 1*/
|
||||
if (IsPrimitiveCNode(nodes_[cluster_id], prim::kPrimTupleGetItem)) {
|
||||
const auto &inputs = graph_->GetInputs(cluster_id);
|
||||
if (inputs.size() != 1) {
|
||||
MS_LOG(ERROR) << "Input size of GetItem(" << cluster_id << ") should be 1, but got " << inputs.size();
|
||||
candidates->clear();
|
||||
return;
|
||||
}
|
||||
auto prev_id = *(inputs.begin());
|
||||
if (!candidates->count(prev_id)) {
|
||||
iter = candidates->erase(iter);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
|
||||
void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) {
|
||||
// process cnode only
|
||||
nodes_ = TopoSort(func_graph->get_return(), SuccIncoming,
|
||||
[](const AnfNodePtr &node) { return node->isa<CNode>() ? FOLLOW : EXCLUDE; });
|
||||
for (size_t i = 0; i < nodes_.size(); i++) {
|
||||
node_idx_map_[nodes_[i]] = i;
|
||||
}
|
||||
graph_ = std::make_shared<Graph>(nodes_, node_idx_map_);
|
||||
MS_EXCEPTION_IF_NULL(graph_);
|
||||
}
|
||||
|
||||
bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) {
|
||||
auto mng = func_graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(mng);
|
||||
Init(func_graph);
|
||||
bool changed = Process(func_graph);
|
||||
if (changed) {
|
||||
if (context::GraphKernelFlags::GetInstance().dump_as_text) {
|
||||
DumpToFile();
|
||||
}
|
||||
mng->RemoveRoots();
|
||||
mng->KeepRoots({func_graph});
|
||||
}
|
||||
Clean();
|
||||
return changed;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_
|
||||
|
||||
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
|
||||
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class Graph;
|
||||
using GraphPtr = std::shared_ptr<Graph>;
|
||||
class GraphKernelCluster : public Pass {
|
||||
public:
|
||||
GraphKernelCluster() : Pass("graph_kernel_cluster") {}
|
||||
~GraphKernelCluster() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
private:
|
||||
void Init(const FuncGraphPtr &func_graph);
|
||||
bool Process(const FuncGraphPtr &func_graph);
|
||||
std::set<size_t> FindCandidates(size_t basenode_id);
|
||||
void RemoveWildGetitem(std::set<size_t> *candidates);
|
||||
void CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id);
|
||||
void DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node);
|
||||
void DumpToFile();
|
||||
void Clean() {
|
||||
std::vector<AnfNodePtr>().swap(nodes_);
|
||||
node_idx_map_.clear();
|
||||
graph_ = nullptr;
|
||||
}
|
||||
|
||||
GraphPtr graph_{nullptr};
|
||||
std::vector<AnfNodePtr> nodes_;
|
||||
std::unordered_map<AnfNodePtr, size_t> node_idx_map_;
|
||||
std::stringstream dump_buf_;
|
||||
};
|
||||
|
||||
bool IsClusterableOp(const AnfNodePtr &node);
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_
|
|
@ -593,77 +593,6 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
|
|||
return name.str();
|
||||
}
|
||||
|
||||
std::vector<PrimitivePtr> GetFusibleOpList() {
|
||||
#if ENABLE_D
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimInplaceAssign,
|
||||
prim::KPrimTransData};
|
||||
#elif ENABLE_GPU
|
||||
std::vector<PrimitivePtr> fusible_basic_ops = {
|
||||
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
|
||||
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
|
||||
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
|
||||
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
|
||||
prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
|
||||
prim::kPrimAssign, prim::kPrimLessEqual, prim::kPrimGreaterEqual, prim::kPrimReduceMax, prim::kPrimReduceMin,
|
||||
prim::kPrimLess, prim::kPrimInplaceAssign};
|
||||
#else
|
||||
std::vector<PrimitivePtr> fusible_basic_ops;
|
||||
#endif
|
||||
const auto &flags = context::GraphKernelFlags::GetInstance();
|
||||
OpListFilter(&fusible_basic_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops);
|
||||
return fusible_basic_ops;
|
||||
}
|
||||
|
||||
bool CheckProcessor(const AnfNodePtr &node, kernel::Processor processor = kernel::Processor::AICORE) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
auto node_kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
|
||||
if (node_kernel_info == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto node_build_info = node_kernel_info->GetMutableSelectKernelBuildInfo();
|
||||
if (node_build_info == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return node_build_info->processor() == processor;
|
||||
}
|
||||
|
||||
bool IsBasicFuseOp(const AnfNodePtr &node) {
|
||||
std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
|
||||
#if ENABLE_D
|
||||
if (!CheckProcessor(node)) {
|
||||
std::vector<PrimitivePtr> fused_aicpu_op = {prim::kPrimExpandDims, prim::kPrimReshape};
|
||||
if (!std::any_of(fused_aicpu_op.begin(), fused_aicpu_op.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return std::any_of(basic_ops.begin(), basic_ops.end(),
|
||||
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
|
||||
}
|
||||
|
||||
bool IsFusibleOp(const AnfNodePtr &node) {
|
||||
#if ENABLE_D
|
||||
const std::set<std::string> graph_kernel_black_list = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
|
||||
"LambNextMV", "LambUpdateWithLR"};
|
||||
if (AnfAlgo::IsGraphKernel(node)) {
|
||||
auto fg_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
|
||||
if (fg_attr != nullptr) {
|
||||
return graph_kernel_black_list.count(GetValue<std::string>(fg_attr)) == 0;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return IsBasicFuseOp(node) || AnfAlgo::IsGraphKernel(node);
|
||||
}
|
||||
|
||||
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -674,37 +603,6 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
|
|||
#endif
|
||||
}
|
||||
|
||||
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
|
||||
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;
|
||||
|
||||
for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) {
|
||||
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimMakeTuple)) {
|
||||
MS_LOG(ERROR) << "Need real outputs of makeTuple";
|
||||
}
|
||||
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
|
||||
if (iter->first == outputs[out_idx]) {
|
||||
new_fuse_cnode_dep_pri.insert({new_fuse_cnode, iter->second});
|
||||
iter = depend_prior->erase(iter);
|
||||
continue;
|
||||
}
|
||||
if (iter->second.first == outputs[out_idx]) {
|
||||
new_fuse_cnode_dep_pri.insert({iter->first, std::make_pair(new_fuse_cnode, iter->second.second)});
|
||||
iter = depend_prior->erase(iter);
|
||||
continue;
|
||||
}
|
||||
++iter;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto item : new_fuse_cnode_dep_pri) {
|
||||
depend_prior->insert(item);
|
||||
}
|
||||
}
|
||||
|
||||
std::string GetFormat(const AnfNodePtr &node) {
|
||||
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
|
|
|
@ -47,6 +47,8 @@ constexpr auto kJsonKeyMultiGraph = "multi_graph";
|
|||
constexpr auto kJsonKeyGraphDesc = "graph_desc";
|
||||
constexpr auto kJsonKeyGraphMode = "graph_mode";
|
||||
|
||||
constexpr auto kGraphKernelDumpPath = "graph_kernel_dump";
|
||||
|
||||
struct DataInfo {
|
||||
std::string format{kOpFormat_DEFAULT};
|
||||
ShapeVector shape{1};
|
||||
|
@ -75,12 +77,7 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
|
|||
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
|
||||
FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs);
|
||||
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
|
||||
std::vector<PrimitivePtr> GetFusibleOpList();
|
||||
bool IsBasicFuseOp(const AnfNodePtr &node);
|
||||
bool IsFusibleOp(const AnfNodePtr &node);
|
||||
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
|
||||
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
|
||||
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);
|
||||
|
||||
std::string GetFormat(const AnfNodePtr &node);
|
||||
TypePtr GetType(const AnfNodePtr &node);
|
||||
|
|
|
@ -25,7 +25,7 @@
|
|||
#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
|
||||
#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
|
||||
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
|
||||
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
|
||||
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
|
||||
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
|
||||
|
@ -65,8 +65,8 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const {
|
|||
// Expand complex basic kernels to composite kernels
|
||||
pm->AddPass(std::make_shared<GraphKernelExpander>());
|
||||
|
||||
// Fuse basic kernels and composite kernels
|
||||
pm->AddPass(std::make_shared<BasicOpsFusion>());
|
||||
// Cluster basic kernels and composite kernels
|
||||
pm->AddPass(std::make_shared<GraphKernelCluster>());
|
||||
|
||||
// Eliminate the outputs without external user
|
||||
pm->AddPass(std::make_shared<EliminateRedundantOutput>());
|
||||
|
|
Loading…
Reference in New Issue