Refactor GraphKernelCluster

Use a disjoint-set to maintain the clusters before building the ANF graph, which speeds up graph construction.
Dump the node changes to a file if the graph kernel flag "dump_as_text" is set.

The algorithm of CheckCircle is unchanged.
Reuse the FuseNodesToSubGraph interface.
This commit is contained in:
dayschan 2021-04-30 18:45:39 +08:00
parent 3ca52269db
commit 3bf0d8a19a
10 changed files with 545 additions and 598 deletions

View File

@ -1,197 +0,0 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include <algorithm>
#include <map>
#include <memory>
#include <unordered_set>
#include <unordered_map>
#include <utility>
#include <set>
#include <string>
#include <vector>
#include "base/core_ops.h"
#include "ir/graph_utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "vm/segment_runner.h"
#include "debug/anf_ir_dump.h"
#include "ir/func_graph_cloner.h"
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/pass/getitem_tuple.h"
namespace mindspore {
namespace opt {
namespace {
// Visitor used when searching along the input direction from `cur_node`.
// Returns FOLLOW for the start node itself, for any fusible op, and for a TupleGetItem whose
// real input is a graph kernel. Every other node is EXCLUDEd.
IncludeType IncludeFusedBasicOpForward(const AnfNodePtr &cur_node, const AnfNodePtr &node) {
  if (cur_node == node || IsFusibleOp(node)) {
    return FOLLOW;
  }
  if (!AnfAlgo::CheckPrimitiveType(node, prim::kPrimTupleGetItem)) {
    return EXCLUDE;
  }
  // A getitem is only followed when the node it reads from is already a graph kernel.
  auto real_input = node->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
  return AnfAlgo::IsGraphKernel(real_input) ? FOLLOW : EXCLUDE;
}
// A GetItem node must be fused together with its real input and with all of its users.
// GetItem nodes that violate this are dropped from the fuse list; the relative order of the
// surviving nodes is the same as in `fused_op`.
AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
  if (fused_op.empty()) {
    return AnfNodePtrList();
  }
  std::set<AnfNodePtr> remaining(fused_op.begin(), fused_op.end());
  auto check_include = [&remaining](const AnfNodePtr &node) { return remaining.count(node) ? FOLLOW : EXCLUDE; };
  auto mng = fused_op[0]->func_graph()->manager();
  MS_EXCEPTION_IF_NULL(mng);

  // Each round handles the first offending getitem found; stop once a round removes nothing.
  for (;;) {
    AnfNodePtrList to_remove;
    for (const auto &item : remaining) {
      if (!AnfAlgo::CheckPrimitiveType(item, prim::kPrimTupleGetItem)) {
        continue;
      }
      // The real input of the getitem is not fused: drop just this getitem.
      auto real_input = item->cast<CNodePtr>()->input(kRealInputNodeIndexInTupleGetItem);
      if (check_include(real_input) == EXCLUDE) {
        to_remove.push_back(item);
        break;
      }
      // Some user of the getitem is not fused: drop everything the search from the getitem returns.
      const auto &users = mng->node_users()[item];
      const bool has_outside_user =
        std::any_of(users.begin(), users.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
          return check_include(user.first) == EXCLUDE;
        });
      if (has_outside_user) {
        to_remove = DeepLinkedGraphSearch(item, check_include);
        break;
      }
    }
    if (to_remove.empty()) {
      break;
    }
    for (const auto &node : to_remove) {
      remaining.erase(node);
    }
  }

  // Rebuild the list in the original order of fused_op.
  AnfNodePtrList kept;
  for (const auto &node : fused_op) {
    if (remaining.count(node) != 0) {
      kept.push_back(node);
    }
  }
  return kept;
}
std::vector<AnfNodePtr> FindFuseCNodes(const CNodePtr &cnode,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &dep_pri) {
// Search fusable nodes according input direction.
auto include_func_forward = std::bind(IncludeFusedBasicOpForward, cnode, std::placeholders::_1);
auto used_nodes = DeepLinkedGraphSearch(cnode, include_func_forward);
if (used_nodes.size() > 1) {
used_nodes = RemoveCircle(used_nodes, dep_pri);
}
used_nodes = RemoveWildGetitem(used_nodes);
TopoSortForNodeList(&used_nodes);
return used_nodes;
}
bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vector<AnfNodePtr> &todos,
std::unordered_set<AnfNodePtr> *fused_ops) {
bool changed = false;
auto mng = kernel_graph->manager();
// depend_prior[depend] = pair(prior, behind)
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> depend_prior;
// InitDependPrior(todos, &depend_prior);
for (auto iter = todos.cbegin(); iter != todos.cend(); ++iter) {
auto node = (*iter)->cast<CNodePtr>();
if (node == nullptr || IsKeepBasicNode(node) || fused_ops->count(node)) {
continue;
}
bool is_fusible_op = IsFusibleOp(node);
if (!is_fusible_op || !kernel_graph->nodes().contains(node)) {
continue;
}
auto fuse_nodes = FindFuseCNodes(node, depend_prior);
if (fuse_nodes.empty()) {
continue;
}
if (fuse_nodes.size() == 1) {
// Do not fuse a single GraphKernel again.
// Do not fuse a single Assign.
if (AnfAlgo::IsGraphKernel(fuse_nodes[0]) || IsPrimitiveCNode(fuse_nodes[0], prim::kPrimAssign)) {
continue;
}
}
changed = true;
fused_ops->insert(fuse_nodes.begin(), fuse_nodes.end());
AnfNodePtr fused_new_node;
AnfNodePtrList old_outputs;
std::tie(fused_new_node, old_outputs) = FuseNodesToSubGraph(fuse_nodes, kernel_graph, "fusion");
ReplaceNewFuseCNodeForDependPrior(&depend_prior, fused_new_node, old_outputs);
}
std::dynamic_pointer_cast<session::KernelGraph>(kernel_graph)->SetExecOrderByDefault();
return changed;
}
} // namespace
bool FuseBasicOps(const FuncGraphPtr &func_graph) {
std::unordered_set<AnfNodePtr> fused_ops;
auto todos = TopoSort(func_graph->get_return());
std::reverse(todos.begin(), todos.end());
return FuseBasicOps(func_graph, todos, &fused_ops);
}
// Run the GetitemTuple pass on the sub-graph of every graph kernel node reachable from the return node.
void EliminateGetitem(const FuncGraphPtr &func_graph) {
  std::shared_ptr<Pass> getitem_pass = std::make_shared<opt::GetitemTuple>();
  auto sorted_nodes = TopoSort(func_graph->get_return());
  for (const auto &node : sorted_nodes) {
    if (!AnfAlgo::IsGraphKernel(node)) {
      continue;
    }
    getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(node));
  }
}
// Pass entry point: fuse the basic ops of `func_graph`.
// A manager is created and attached when the graph has none. When anything was fused, getitem
// nodes inside the new graph kernels are eliminated and the manager roots are reset to `func_graph`.
// Returns true if the graph was changed.
bool BasicOpsFusion::Run(const FuncGraphPtr &func_graph) {
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = Manage(func_graph, true);
    func_graph->set_manager(manager);
  }
  if (!FuseBasicOps(func_graph)) {
    return false;
  }
  EliminateGetitem(func_graph);
  manager->RemoveRoots();
  manager->KeepRoots({func_graph});
  return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -1,36 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace opt {
// Fuses the basic ops of `kernel_graph`; returns true if any nodes were fused.
bool FuseBasicOps(const FuncGraphPtr &kernel_graph);
// Pass that applies FuseBasicOps to a graph. It is constructed with the pass name "basic_ops_fusion".
class BasicOpsFusion : public Pass {
public:
BasicOpsFusion() : Pass("basic_ops_fusion") {}
~BasicOpsFusion() override = default;
// Runs the fusion on `func_graph`; returns true if the graph was changed.
bool Run(const FuncGraphPtr &func_graph) override;
};
// Shared-pointer alias for the pass.
using FuseBasicPtr = std::shared_ptr<BasicOpsFusion>;
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_BASIC_OPS_FUSION_H_

View File

@ -18,6 +18,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
namespace mindspore {
namespace opt {
@ -99,7 +100,7 @@ bool CastMatmulFusion::Run(const FuncGraphPtr &func_graph) {
continue;
}
// Cast cannot fuse with its input
if (IsFusibleOp((cast_node->cast<CNodePtr>())->input(1))) {
if (IsClusterableOp((cast_node->cast<CNodePtr>())->input(1))) {
continue;
}

View File

@ -1,217 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/composite_ops_fusion.h"
#include <algorithm>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector>
#include "frontend/operator/ops.h"
#include "utils/utils.h"
#include "utils/ordered_set.h"
#include "utils/ordered_map.h"
#include "ir/graph_utils.h"
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "vm/segment_runner.h"
#include "debug/draw.h"
#include "debug/anf_ir_dump.h"
#include "ir/func_graph_cloner.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/pass/getitem_tuple.h"
namespace mindspore {
namespace opt {
namespace {
std::vector<AnfNodePtr> DeepLinkedGraphSearch(const std::vector<AnfNodePtr> &roots, const IncludeFunc &include) {
std::vector<AnfNodePtr> inputs;
for (auto &root : roots) {
auto tmp = DeepLinkedGraphSearch(root, include);
inputs.insert(inputs.end(), tmp.begin(), tmp.end());
}
return inputs;
}
} // namespace
// Check whether fusing `check_node` with the rest of `fused_op_set` would form a circle.
//   fused_op_set:           the nodes that are currently planned to be fused together.
//   check_node:             the node to check; anything that is not a CNode inside fused_op_set
//                           makes the function return false immediately (circle_nodes is left untouched).
//   cached_unconnected_set: in/out cache of nodes from which no node of fused_op_set was reached;
//                           they are skipped by later searches, including later calls.
//   circle_nodes:           output, cleared first, then filled with the fused nodes that cause a circle.
//   depend_prior:           extra ordering edges; depend_prior[node] holds pair(prior, behind) and the
//                           `prior` node is treated as an additional input of `node`.
// Returns true if circle_nodes is not empty.
bool CheckCircle(const std::set<AnfNodePtr> &fused_op_set, const AnfNodePtr &check_node,
std::set<AnfNodePtr> *cached_unconnected_set, std::vector<AnfNodePtr> *circle_nodes,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
if (!check_node->isa<CNode>() || !fused_op_set.count(check_node)) {
return false;
}
circle_nodes->clear();
// The edges to follow from a node: its real inputs plus the `prior` nodes recorded for it in depend_prior.
auto InputEdges = [&depend_prior](const CNodePtr &cnode) {
std::set<AnfNodePtr> edges;
auto range = depend_prior.equal_range(cnode);
for (auto iter = range.first; iter != range.second; ++iter) {
edges.insert(iter->second.first);
}
auto inputs = cnode->inputs();
for (auto input : inputs) {
edges.insert(input);
}
return edges;
};
// consider prior depend both in fused_op_set
auto range = depend_prior.equal_range(check_node);
for (auto iter = range.first; iter != range.second; ++iter) {
if (fused_op_set.count(iter->second.first)) {
circle_nodes->push_back(iter->second.first);
}
}
// Nodes visited by a search of this call that did reach a fused node; they are not searched again.
std::set<AnfNodePtr> cached_done_set;
auto cnode = check_node->cast<CNodePtr>();
const auto &inputs = InputEdges(cnode);
// there is a input not in fused_op_set, but the input depends on the fused_op_set
for (auto input : inputs) {
if (input->isa<CNode>() && !fused_op_set.count(input)) {
bool has_circle = false;
// Nodes visited by the search started from this input.
std::set<AnfNodePtr> done;
// Stack of nodes still to visit.
std::vector<AnfNodePtr> todos = {input};
while (!todos.empty()) {
auto node = todos.back();
todos.pop_back();
if (done.count(node) || cached_unconnected_set->count(node) || cached_done_set.count(node)) {
continue;
}
done.insert(node);
// Reaching a fused node from an unfused input means a circle; do not search past it.
if (fused_op_set.count(node)) {
has_circle = true;
circle_nodes->push_back(node);
continue;
}
if (node->isa<CNode>()) {
auto cnode_ptr = node->cast<CNodePtr>();
for (auto it : InputEdges(cnode_ptr)) {
if (it->isa<CNode>()) {
todos.push_back(it);
}
}
}
}
// Remember the visited nodes in the cache matching the outcome of this search.
if (has_circle) {
cached_done_set.insert(done.begin(), done.end());
} else {
cached_unconnected_set->insert(done.begin(), done.end());
}
done.clear();
}
}
return !circle_nodes->empty();
}
// Drop from `fused_op` the nodes that would form a circle if they were fused.
// The nodes are checked from the back of the list to the front. When a circle is detected, the
// nodes returned by searching from the circle nodes (restricted to the nodes still kept) are removed.
// The result keeps the original order of `fused_op`.
AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
                            const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior) {
  std::set<AnfNodePtr> unconnected_cache;
  std::set<AnfNodePtr> kept(fused_op.begin(), fused_op.end());
  auto in_kept = [&kept](const AnfNodePtr &node) { return kept.count(node) ? FOLLOW : EXCLUDE; };

  std::vector<AnfNodePtr> circle_nodes;
  for (auto rit = fused_op.rbegin(); rit != fused_op.rend(); ++rit) {
    circle_nodes.clear();
    if (!CheckCircle(kept, *rit, &unconnected_cache, &circle_nodes, depend_prior)) {
      continue;
    }
    // delete the circle node and the node which depend on the circle node in fused op
    std::vector<AnfNodePtr> victims = DeepLinkedGraphSearch(circle_nodes, in_kept);
    for (const auto &victim : victims) {
      kept.erase(victim);
    }
  }

  std::vector<AnfNodePtr> survivors;
  for (const auto &node : fused_op) {
    if (kept.count(node) != 0) {
      survivors.push_back(node);
    }
  }
  return survivors;
}
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst) {
if (lst->size() < 2) {
return;
}
std::vector<AnfNodePtr> res;
std::set<AnfNodePtr> node_sets(lst->begin(), lst->end());
OrderedMap<AnfNodePtr, std::set<AnfNodePtr>> ins;
OrderedMap<AnfNodePtr, OrderedSet<AnfNodePtr>> outs;
std::queue<AnfNodePtr> q;
for (auto node : *lst) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
for (auto input : cnode->inputs()) {
if (!node_sets.count(input)) {
continue;
}
// out_degree
outs[input].insert(node);
// in_degree
ins[node].insert(input);
}
if (!ins.count(node)) {
ins[node] = {};
}
}
for (auto p : ins) {
if (p.second.size() == 0) {
q.push(p.first);
}
}
while (!q.empty()) {
auto node = q.front();
q.pop();
res.push_back(node);
if (!outs.count(node)) {
continue;
}
for (auto out : outs[node]) {
if (!ins.count(out)) {
continue;
}
ins[out].erase(node);
if (ins[out].size() == 0) {
q.push(out);
}
}
}
lst->assign(res.begin(), res.end());
}
} // namespace opt
} // namespace mindspore

View File

@ -1,37 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_
#include <limits>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h"
namespace mindspore {
namespace opt {
// Returns the nodes of `fused_op`, in their original order, that remain after removing the nodes
// which would form a circle when fused. `depend_prior` maps a node to pair(prior, behind); the
// `prior` node is treated as an extra input of that node during the check.
AnfNodePtrList RemoveCircle(const std::vector<AnfNodePtr> &fused_op,
const std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> &depend_prior);
// Reorders `lst` in place so that each node comes after those of its inputs that are also in `lst`.
void TopoSortForNodeList(std::vector<AnfNodePtr> *lst);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_COMPOSITE_OPS_FUSION_H_

View File

@ -0,0 +1,477 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
#include <algorithm>
#include <map>
#include <unordered_map>
#include <set>
#include <vector>
#include <memory>
#include <utility>
#include <fstream>
#include "base/core_ops.h"
#include "ir/graph_utils.h"
#include "debug/common.h"
#include "utils/context/graph_kernel_flags.h"
#include "backend/kernel_compiler/common_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/optimizer/pass/getitem_tuple.h"
namespace mindspore {
namespace opt {
namespace {
// Build the list of primitives that are allowed to be clustered.
// The first entries are common to all builds; ENABLE_D and ENABLE_GPU builds append their own ops.
// The list is then passed through OpListFilter together with the graph kernel flags
// enable_cluster_ops_only / enable_cluster_ops / disable_cluster_ops before it is returned.
// NOTE(review): the exact filtering rules live in OpListFilter, which is not visible here.
std::vector<PrimitivePtr> GetClusterableOpList() {
std::vector<PrimitivePtr> clusterable_ops = {
prim::kPrimAbs,
prim::kPrimRound,
prim::kPrimNeg,
prim::kPrimExp,
prim::kPrimAdd,
prim::kPrimCast,
prim::kPrimMul,
prim::kPrimMinimum,
prim::kPrimMaximum,
prim::kPrimLog,
prim::kPrimPow,
prim::kPrimSub,
prim::kPrimRsqrt,
prim::kPrimSqrt,
prim::kPrimAddN,
prim::kPrimReciprocal,
prim::kPrimTanh,
prim::kPrimReshape,
prim::kPrimTranspose,
prim::kPrimRealDiv,
prim::kPrimReduceSum,
prim::kPrimEqual,
prim::kPrimAssign,
prim::kPrimInplaceAssign,
#if ENABLE_D
// Ops that are added only when ENABLE_D is set.
prim::kPrimMatMul,
prim::KPrimTransData,
#elif ENABLE_GPU
// Ops that are added only when ENABLE_GPU is set (and ENABLE_D is not).
prim::kPrimReduceMax,
prim::kPrimReduceMin,
prim::kPrimGreater,
prim::kPrimLess,
prim::kPrimGreaterEqual,
prim::kPrimLessEqual,
prim::kPrimSelect,
#endif
};
const auto &flags = context::GraphKernelFlags::GetInstance();
OpListFilter(&clusterable_ops, flags.enable_cluster_ops_only, flags.enable_cluster_ops, flags.disable_cluster_ops);
return clusterable_ops;
}
size_t CountGraphKernelInnerNodes(const AnfNodePtr &node) {
AnfNodePtrList node_list;
kernel::GetValidKernelNodes(AnfAlgo::GetCNodeFuncGraphPtr(node), &node_list);
return node_list.size();
}
} // namespace
bool IsClusterableOp(const AnfNodePtr &node) {
if (IsKeepBasicNode(node)) {
return false;
}
if (AnfAlgo::IsGraphKernel(node)) {
return true;
}
auto op_list = GetClusterableOpList();
bool node_in_oplist = std::any_of(op_list.begin(), op_list.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
if (!node_in_oplist) {
return false;
}
#if ENABLE_D
// For AICPU operators, only the Reshape can be clustered.
if (AnfAlgo::GetProcessor(node) != kernel::Processor::AICORE && !IsPrimitiveCNode(node, prim::kPrimReshape)) {
return false;
}
#endif
return true;
}
// Disjoint-set over the node list, used to maintain the clusters before any func_graph is built.
// Every node starts as its own cluster whose id is the node's index in the list. Merging clusters
// re-parents them to one representative, and the inputs of a cluster are stored as cluster ids.
class Graph {
  struct Cluster {
    size_t cluster_id_;        // node_id of the representative.
    size_t cluster_size_{1};   // size of cluster, composite node is considered as one node.
    size_t basic_op_cnt_{1};   // basic node count, the inner nodes of composite node are counted.
    std::set<size_t> inputs_;  // inputs' cluster_id.
    size_t seed_{0};           // visited flag of dfs.

    // Create the single-node cluster of `node`, whose index is `node_id`.
    // Only inputs that are present in `node_idx_map` are recorded.
    Cluster(size_t node_id, const AnfNodePtr &node, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map)
        : cluster_id_(node_id) {
      if (IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
        basic_op_cnt_ = 0;
      } else if (AnfAlgo::IsGraphKernel(node)) {
        // the basic_op_cnt_ is used to limit the composite op size
        basic_op_cnt_ = CountGraphKernelInnerNodes(node);
      }
      auto cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      for (const auto &inp : cnode->inputs()) {
        auto iter = node_idx_map.find(inp);
        if (iter != node_idx_map.end()) {
          // At the beginning, cluster_id is equal to node_id
          inputs_.insert(iter->second);
        }
      }
    }
    ~Cluster() = default;

    // Absorb `other_cluster`: it is re-parented to this cluster and its counters and inputs are taken over.
    void Merge(Cluster *other_cluster) {
      other_cluster->cluster_id_ = cluster_id_;
      cluster_size_ += other_cluster->cluster_size_;
      basic_op_cnt_ += other_cluster->basic_op_cnt_;
      inputs_.insert(other_cluster->inputs_.begin(), other_cluster->inputs_.end());
      other_cluster->Clean();
    }

    // clean the info to free memory.
    void Clean() {
      inputs_.clear();
      cluster_size_ = 0;
      basic_op_cnt_ = 0;
    }
  };  // struct Cluster

 public:
  // Init and build graph
  Graph(const AnfNodePtrList &nodes, const std::unordered_map<AnfNodePtr, size_t> &node_idx_map) {
    clusters_.reserve(nodes.size());
    for (size_t i = 0; i < nodes.size(); i++) {
      clusters_.emplace_back(i, nodes[i], node_idx_map);
    }
  }
  ~Graph() = default;

  // find the representative of the cluster
  // The id is returned as size_t (the type of all ids) so that it is never narrowed. The lookup is
  // iterative: first walk up to the representative, then point every node on the path directly at it.
  size_t Find(size_t node_id) {
    size_t root = node_id;
    while (clusters_[root].cluster_id_ != root) {
      root = clusters_[root].cluster_id_;
    }
    while (clusters_[node_id].cluster_id_ != root) {
      const size_t parent = clusters_[node_id].cluster_id_;
      clusters_[node_id].cluster_id_ = root;
      node_id = parent;
    }
    return root;
  }

  // merge clusters, the smallest cluster id will be the new cluster id.
  // An empty candidate set is a no-op.
  void Merge(const std::set<size_t> &candidates) {
    if (candidates.empty()) {
      return;
    }
    auto iter = candidates.begin();
    const size_t target = *iter;
    for (++iter; iter != candidates.end(); ++iter) {
      clusters_[target].Merge(&clusters_[*iter]);
    }
  }

  // Collect nodes together that are in the same cluster.
  // The result is indexed by representative id; entries of non-representative ids stay empty.
  std::vector<std::vector<size_t>> CollectClusters() {
    std::vector<std::vector<size_t>> cluster_map(clusters_.size());
    for (size_t i = 0; i < clusters_.size(); i++) {
      cluster_map[Find(i)].push_back(i);
    }
    return cluster_map;
  }

  using VisitFunc = std::function<IncludeType(size_t)>;
  // Depth-first search along the input direction, starting from the cluster of `node_id`.
  // The inputs of a cluster are only traversed when `visitor` returns FOLLOW for it.
  void Dfs(size_t node_id, VisitFunc visitor) {
    // A fresh stamp invalidates the visited flags of every previous search.
    ++seen_;
    DepthFirstSearch(Find(node_id), visitor);
  }

  // Get cluster size
  size_t GetSize(size_t cluster_id) { return clusters_[Find(cluster_id)].cluster_size_; }

  // Get cluster's basic op count
  size_t GetBasicNodeCount(size_t cluster_id) { return clusters_[Find(cluster_id)].basic_op_cnt_; }

  // Get cluster's inputs
  const std::set<size_t> &GetInputs(size_t cluster_id) {
    cluster_id = Find(cluster_id);
    RefreshInputs(cluster_id);
    return clusters_[cluster_id].inputs_;
  }

 private:
  // Replace every stored input id by its current representative and drop the self reference.
  void RefreshInputs(size_t i) {
    auto &inputs = clusters_[i].inputs_;
    for (auto iter = inputs.begin(); iter != inputs.end();) {
      size_t new_id = Find(*iter);
      if (new_id != *iter) {
        iter = inputs.erase(iter);
        inputs.insert(new_id);
      } else {
        ++iter;
      }
    }
    inputs.erase(i);
  }

  void DepthFirstSearch(size_t cluster_id, const VisitFunc &visitor) {
    // Already visited during the current search.
    if (clusters_[cluster_id].seed_ >= seen_) return;
    clusters_[cluster_id].seed_ = seen_;
    if (visitor(cluster_id) != FOLLOW) {
      return;
    }
    // traverse inputs in descending order.
    const auto &inputs = GetInputs(cluster_id);
    for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
      DepthFirstSearch(*iter, visitor);
    }
  }

  std::vector<Cluster> clusters_;
  size_t seen_{0};
};  // class Graph
// Removes from a candidate set the clusters that would form a circle if the candidates were merged.
// The checker works on cluster ids of the bound Graph.
class CircleChecker {
public:
explicit CircleChecker(GraphPtr graph) : graph_(graph) {}
~CircleChecker() = default;
// Check every candidate in turn and erase the ones that cause a circle from `*candidates`.
// Sets with at most one element are left unchanged.
void RemoveCircle(std::set<size_t> *candidates) {
if (candidates->size() <= 1) {
return;
}
candidates_ = candidates;
// Iterate over a snapshot, because the candidate set shrinks while circles are removed.
std::vector<size_t> tmp_list(candidates->begin(), candidates->end());
for (auto c : tmp_list) {
if (!candidates->count(c)) continue;
circle_nodes_.clear();
if (CheckCircle(c)) {
RemoveCircleNodesFromCandidates();
}
}
}
private:
/**
* Check circle. the candidate is collected into circle_nodes_ if it will form a circle.
*
* algorithm:
* Search from the basenode's input that is NOT in candidates (the basenode is a candidate),
* If it depends on a node that belongs to candidates, it will form a circle.
* e.g. A -> x -> ... -> B
* -> y -> ... -> C
* In this case, A, B and C are candidates while x and y are not.
* Both x and y are inputs of A. assumes A is the basenode.
* When searching from x, the B will be found and added into circle_nodes list,
* and then when searching from y, the C will be found and added into circle_nodes list.
*/
bool CheckCircle(size_t basenode) {
const auto &inputs = graph_->GetInputs(basenode);
// Nodes visited by a search of this call that did reach a candidate.
std::set<size_t> visited_circle_nodes;
for (auto x : inputs) {
if (candidates_->count(x)) continue;
bool has_circle = false;
// Nodes visited by the search started from x.
std::set<size_t> done;
auto vis_func = [this, &has_circle, &done, &visited_circle_nodes](size_t node_id) {
if (done.count(node_id) || acyclic_nodes_.count(node_id) || visited_circle_nodes.count(node_id)) {
return EXCLUDE;
}
done.insert(node_id);
// Reaching a candidate from a non-candidate input means a circle; do not search past it.
if (candidates_->count(node_id)) {
has_circle = true;
circle_nodes_.push_back(node_id);
return EXCLUDE;
}
// all nodes are indexed by topo order,
// so if the node_id is less than the minimal candidate, a cycle cannot be formed from this node.
if (candidates_->empty() || node_id < *candidates_->begin()) {
return EXCLUDE;
}
return FOLLOW;
};
graph_->Dfs(x, vis_func);
// Remember the visited nodes in the cache matching the outcome of this search.
if (has_circle) {
visited_circle_nodes.insert(done.begin(), done.end());
} else {
acyclic_nodes_.insert(done.begin(), done.end());
}
}
return !circle_nodes_.empty();
}
// remove all circle nodes from candidates
// Starting from each circle node, erase it and every candidate reached through candidates only.
void RemoveCircleNodesFromCandidates() {
auto remove_from_candidates = [this](size_t node_id) {
if (candidates_->count(node_id)) {
candidates_->erase(node_id);
return FOLLOW;
}
return EXCLUDE;
};
for (auto node : circle_nodes_) {
graph_->Dfs(node, remove_from_candidates);
}
}
private:
GraphPtr graph_; // bind the global graph
std::set<size_t> *candidates_{nullptr}; // bind the input candidates
std::vector<size_t> circle_nodes_; // candidates found to cause a circle by the latest CheckCircle
std::set<size_t> acyclic_nodes_; // nodes from which no candidate was reached; skipped by later searches
}; // CircleChecker
std::set<size_t> GraphKernelCluster::FindCandidates(size_t basenode_id) {
std::set<size_t> candidates;
auto include = [this, &candidates, func_graph = nodes_[basenode_id]->func_graph()](size_t cluster_id) {
const AnfNodePtr &node = this->nodes_[cluster_id];
if (node->func_graph() != func_graph) {
return EXCLUDE;
}
if (!IsClusterableOp(node) && !IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) {
return EXCLUDE;
}
candidates.insert(cluster_id);
// Do not search from clustered node again.
if (this->graph_->GetSize(cluster_id) > 1) {
return NOFOLLOW;
}
return FOLLOW;
};
graph_->Dfs(basenode_id, include);
return candidates;
}
// Cluster the nodes collected by Init() and rebuild the func_graph accordingly.
// Phase 1 walks the nodes from last to first and merges each unclustered node with its valid
// candidates in the disjoint-set. Phase 2 turns every resulting cluster into a sub-graph.
// Returns true if at least one sub-graph was created.
bool GraphKernelCluster::Process(const FuncGraphPtr &func_graph) {
  bool changed = false;
  // Count down with an unsigned index so that the node count is never narrowed to int.
  for (size_t i = nodes_.size(); i-- > 0;) {
    // if the node has been clustered, it has tried to find its previous nodes, so it's unnecessary to try again.
    if (graph_->GetSize(i) > 1) {
      continue;
    }
    auto candidates = FindCandidates(i);
    CircleChecker(graph_).RemoveCircle(&candidates);
    RemoveWildGetitem(&candidates);
    if (candidates.empty()) continue;
    // merge candidates into one cluster
    graph_->Merge(candidates);
  }

  // Rebuild func_graphs
  auto clusters = graph_->CollectClusters();
  for (const auto &cluster : clusters) {
    auto node_without_getitem = std::count_if(cluster.begin(), cluster.end(), [this](size_t node_id) {
      return !IsPrimitiveCNode(this->nodes_[node_id], prim::kPrimTupleGetItem);
    });
    // Nothing to build for empty slots or clusters made of getitems only.
    if (node_without_getitem == 0) continue;
    if (node_without_getitem == 1) {
      // Do not cluster a single GraphKernel again.
      // Do not cluster a single Assign.
      const auto &node = nodes_[cluster[0]];
      if (AnfAlgo::IsGraphKernel(node) || IsPrimitiveCNode(node, prim::kPrimAssign) || !IsClusterableOp(node)) {
        continue;
      }
    }
    CreateFuncGraph(func_graph, cluster);
    changed = true;
  }
  return changed;
}
void GraphKernelCluster::CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id) {
AnfNodePtrList old_nodes;
AnfNodePtr new_node;
std::transform(nodes_id.begin(), nodes_id.end(), std::back_inserter(old_nodes),
[this](size_t id) { return this->nodes_[id]; });
std::tie(new_node, std::ignore) = FuseNodesToSubGraph(old_nodes, func_graph, "fusion");
std::shared_ptr<Pass> eliminate_getitem_pass = std::make_shared<opt::GetitemTuple>();
eliminate_getitem_pass->Run(AnfAlgo::GetCNodeFuncGraphPtr(new_node));
if (context::GraphKernelFlags::GetInstance().dump_as_text) {
DumpClusterInfo(old_nodes, new_node);
}
}
// Append a description of one clustering step to the dump buffer: the new node first, then the
// nodes it was built from, then a separator line. Does nothing unless ENABLE_DUMP_IR is defined.
void GraphKernelCluster::DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node) {
#ifdef ENABLE_DUMP_IR
  auto write_entry = [this](const char *prefix, const AnfNodePtr &node) {
    dump_buf_ << prefix << node->fullname_with_scope() << " = " << node->DebugString() << std::endl;
  };
  write_entry("Source nodes of ", new_node);
  for (const auto &old_node : old_nodes) {
    write_entry(" ", old_node);
  }
  dump_buf_ << "=======================" << std::endl;
#endif
}
// Append the buffered cluster information to "graph_kernel_cluster.txt" in the graph kernel dump
// directory. Does nothing unless ENABLE_DUMP_IR is defined.
// The buffer is emptied on every call, including the failure paths. The file is opened in append
// mode, so text kept in the buffer would be written again by the next call.
void GraphKernelCluster::DumpToFile() {
#ifdef ENABLE_DUMP_IR
  // Take the buffered text and reset the stream (content and state flags) before doing any I/O.
  const std::string content = dump_buf_.str();
  dump_buf_.str("");
  dump_buf_.clear();

  auto pathname = std::string("./") + kGraphKernelDumpPath + "/graph_kernel_cluster.txt";
  auto realpath = Common::GetRealPath(pathname);
  if (!realpath.has_value()) {
    MS_LOG(ERROR) << "Get real path failed. path=" << pathname;
    return;
  }
  std::ofstream fout(realpath.value(), std::ios::app);
  if (!fout.is_open()) {
    MS_LOG(ERROR) << "Open dump file '" << realpath.value() << "' failed!";
    return;
  }
  fout << content << std::endl;
  fout.close();
#endif
}
// The GetItem node should be clustered with its real input.
// If its real input is not in the candidates, the GetItem should be excluded.
void GraphKernelCluster::RemoveWildGetitem(std::set<size_t> *candidates) {
for (auto iter = candidates->begin(); iter != candidates->end();) {
size_t cluster_id = *iter;
/*The implied condition is graph->GetSize(cluster_id) == 1*/
if (IsPrimitiveCNode(nodes_[cluster_id], prim::kPrimTupleGetItem)) {
const auto &inputs = graph_->GetInputs(cluster_id);
if (inputs.size() != 1) {
MS_LOG(ERROR) << "Input size of GetItem(" << cluster_id << ") should be 1, but got " << inputs.size();
candidates->clear();
return;
}
auto prev_id = *(inputs.begin());
if (!candidates->count(prev_id)) {
iter = candidates->erase(iter);
continue;
}
}
++iter;
}
}
void GraphKernelCluster::Init(const FuncGraphPtr &func_graph) {
// process cnode only
nodes_ = TopoSort(func_graph->get_return(), SuccIncoming,
[](const AnfNodePtr &node) { return node->isa<CNode>() ? FOLLOW : EXCLUDE; });
for (size_t i = 0; i < nodes_.size(); i++) {
node_idx_map_[nodes_[i]] = i;
}
graph_ = std::make_shared<Graph>(nodes_, node_idx_map_);
MS_EXCEPTION_IF_NULL(graph_);
}
// Pass entry point: cluster the nodes of `func_graph`.
// The graph must already have a manager. When anything was clustered, the cluster information is
// dumped if the "dump_as_text" flag is set, and the manager roots are reset to `func_graph`.
// The working state is released before returning. Returns true if the graph was changed.
bool GraphKernelCluster::Run(const FuncGraphPtr &func_graph) {
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  Init(func_graph);
  const bool changed = Process(func_graph);
  if (!changed) {
    Clean();
    return false;
  }
  if (context::GraphKernelFlags::GetInstance().dump_as_text) {
    DumpToFile();
  }
  manager->RemoveRoots();
  manager->KeepRoots({func_graph});
  Clean();
  return true;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_
#include <vector>
#include <string>
#include <unordered_map>
#include <set>
#include <memory>
#include "ir/anf.h"
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
// Disjoint-set structure that maintains the clusters; defined in the implementation file.
class Graph;
using GraphPtr = std::shared_ptr<Graph>;
// Pass that groups clusterable nodes of a func_graph and fuses every group into a sub-graph.
// It is constructed with the pass name "graph_kernel_cluster".
class GraphKernelCluster : public Pass {
public:
GraphKernelCluster() : Pass("graph_kernel_cluster") {}
~GraphKernelCluster() override = default;
// Runs the clustering on `func_graph`; returns true if the graph was changed.
bool Run(const FuncGraphPtr &func_graph) override;
private:
// Builds nodes_, node_idx_map_ and graph_ from `func_graph`.
void Init(const FuncGraphPtr &func_graph);
// Merges the clusters and creates the sub-graphs; returns true if any sub-graph was created.
bool Process(const FuncGraphPtr &func_graph);
// Collects the cluster ids that could be merged with the node `basenode_id`.
std::set<size_t> FindCandidates(size_t basenode_id);
// Removes the GetItem candidates whose real input is not a candidate.
void RemoveWildGetitem(std::set<size_t> *candidates);
// Fuses the nodes with the given ids into one sub-graph of `func_graph`.
void CreateFuncGraph(const FuncGraphPtr &func_graph, const std::vector<size_t> &nodes_id);
// Appends the description of one clustering step to dump_buf_.
void DumpClusterInfo(const AnfNodePtrList &old_nodes, const AnfNodePtr &new_node);
// Writes dump_buf_ to the dump file.
void DumpToFile();
// Releases the per-run state (nodes_, node_idx_map_, graph_); dump_buf_ is not touched here.
void Clean() {
std::vector<AnfNodePtr>().swap(nodes_);
node_idx_map_.clear();
graph_ = nullptr;
}
GraphPtr graph_{nullptr}; // disjoint-set over nodes_
std::vector<AnfNodePtr> nodes_; // the CNodes of the graph, in topo order
std::unordered_map<AnfNodePtr, size_t> node_idx_map_; // node -> index in nodes_
std::stringstream dump_buf_; // text collected by DumpClusterInfo
};
// Returns true if `node` may take part in a cluster.
bool IsClusterableOp(const AnfNodePtr &node);
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_GRAPH_KERNEL_CLUSTER_H_

View File

@ -593,77 +593,6 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
return name.str();
}
// Returns the primitives treated as fusible basic ops for the current build target.
//
// The initial list is chosen at compile time (ENABLE_D, ENABLE_GPU, or empty when
// neither is set) and is then handed to OpListFilter together with the
// enable_cluster_ops_only / enable_cluster_ops / disable_cluster_ops graph-kernel flags.
std::vector<PrimitivePtr> GetFusibleOpList() {
#if ENABLE_D
  std::vector<PrimitivePtr> candidate_ops = {
    prim::kPrimAbs,        prim::kPrimRound,         prim::kPrimNeg,     prim::kPrimExp,
    prim::kPrimAdd,        prim::kPrimCast,          prim::kPrimMul,     prim::kPrimMinimum,
    prim::kPrimMaximum,    prim::kPrimLog,           prim::kPrimPow,     prim::kPrimSub,
    prim::kPrimRsqrt,      prim::kPrimSqrt,          prim::kPrimAddN,    prim::kPrimEqual,
    prim::kPrimReciprocal, prim::kPrimTanh,          prim::kPrimReshape, prim::kPrimTranspose,
    prim::kPrimRealDiv,    prim::kPrimMatMul,        prim::kPrimAssign,  prim::kPrimReduceSum,
    prim::kPrimInplaceAssign, prim::KPrimTransData};
#elif ENABLE_GPU
  std::vector<PrimitivePtr> candidate_ops = {
    prim::kPrimAbs,        prim::kPrimRound,        prim::kPrimNeg,          prim::kPrimExp,
    prim::kPrimAdd,        prim::kPrimRealDiv,      prim::kPrimMul,          prim::kPrimMinimum,
    prim::kPrimMaximum,    prim::kPrimLog,          prim::kPrimPow,          prim::kPrimSub,
    prim::kPrimRsqrt,      prim::kPrimSqrt,         prim::kPrimAddN,         prim::kPrimEqual,
    prim::kPrimReciprocal, prim::KPrimTransData,    prim::kPrimSelect,       prim::kPrimGreater,
    prim::kPrimCast,       prim::kPrimReduceSum,    prim::kPrimTanh,         prim::kPrimReshape,
    prim::kPrimTranspose,  prim::kPrimAssign,       prim::kPrimLessEqual,    prim::kPrimGreaterEqual,
    prim::kPrimReduceMax,  prim::kPrimReduceMin,    prim::kPrimLess,         prim::kPrimInplaceAssign};
#else
  // No supported backend compiled in: start from an empty list.
  std::vector<PrimitivePtr> candidate_ops;
#endif
  const auto &gk_flags = context::GraphKernelFlags::GetInstance();
  OpListFilter(&candidate_ops, gk_flags.enable_cluster_ops_only, gk_flags.enable_cluster_ops,
               gk_flags.disable_cluster_ops);
  return candidate_ops;
}
bool CheckProcessor(const AnfNodePtr &node, kernel::Processor processor = kernel::Processor::AICORE) {
MS_EXCEPTION_IF_NULL(node);
auto node_kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
if (node_kernel_info == nullptr) {
return false;
}
auto node_build_info = node_kernel_info->GetMutableSelectKernelBuildInfo();
if (node_build_info == nullptr) {
return false;
}
return node_build_info->processor() == processor;
}
bool IsBasicFuseOp(const AnfNodePtr &node) {
std::vector<PrimitivePtr> basic_ops = GetFusibleOpList();
#if ENABLE_D
if (!CheckProcessor(node)) {
std::vector<PrimitivePtr> fused_aicpu_op = {prim::kPrimExpandDims, prim::kPrimReshape};
if (!std::any_of(fused_aicpu_op.begin(), fused_aicpu_op.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); })) {
return false;
}
}
#endif
return std::any_of(basic_ops.begin(), basic_ops.end(),
[&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}
// Tells whether `node` can take part in fusion: either a basic fusible op
// (IsBasicFuseOp) or a graph kernel (AnfAlgo::IsGraphKernel).
//
// With ENABLE_D, a graph kernel whose FUNC_GRAPH_ATTR_GRAPH_KERNEL attribute is set
// is decided solely by the exclusion list below: fusible iff its name is not listed.
bool IsFusibleOp(const AnfNodePtr &node) {
#if ENABLE_D
  if (AnfAlgo::IsGraphKernel(node)) {
    auto kernel_name_attr = AnfAlgo::GetCNodeFuncGraphPtr(node)->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL);
    if (kernel_name_attr != nullptr) {
      const std::set<std::string> excluded_kernels = {"BNTrainingUpdateSum", "ApplyMomentum", "LayerNormForward",
                                                      "LambNextMV", "LambUpdateWithLR"};
      return excluded_kernels.find(GetValue<std::string>(kernel_name_attr)) == excluded_kernels.end();
    }
  }
#endif
  if (IsBasicFuseOp(node)) {
    return true;
  }
  return AnfAlgo::IsGraphKernel(node);
}
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
@ -674,37 +603,6 @@ void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type) {
#endif
}
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs) {
std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> new_fuse_cnode_dep_pri;
for (size_t out_idx = 0; out_idx < outputs.size(); ++out_idx) {
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimMakeTuple)) {
MS_LOG(ERROR) << "Need real outputs of makeTuple";
}
if (IsPrimitiveCNode(outputs[out_idx], prim::kPrimTupleGetItem)) {
continue;
}
for (auto iter = (*depend_prior).begin(); iter != (*depend_prior).end();) {
if (iter->first == outputs[out_idx]) {
new_fuse_cnode_dep_pri.insert({new_fuse_cnode, iter->second});
iter = depend_prior->erase(iter);
continue;
}
if (iter->second.first == outputs[out_idx]) {
new_fuse_cnode_dep_pri.insert({iter->first, std::make_pair(new_fuse_cnode, iter->second.second)});
iter = depend_prior->erase(iter);
continue;
}
++iter;
}
}
for (auto item : new_fuse_cnode_dep_pri) {
depend_prior->insert(item);
}
}
std::string GetFormat(const AnfNodePtr &node) {
auto kernel_info = static_cast<device::KernelInfo *>(node->kernel_info());
MS_EXCEPTION_IF_NULL(kernel_info);

View File

@ -47,6 +47,8 @@ constexpr auto kJsonKeyMultiGraph = "multi_graph";
constexpr auto kJsonKeyGraphDesc = "graph_desc";
constexpr auto kJsonKeyGraphMode = "graph_mode";
constexpr auto kGraphKernelDumpPath = "graph_kernel_dump";
struct DataInfo {
std::string format{kOpFormat_DEFAULT};
ShapeVector shape{1};
@ -75,12 +77,7 @@ bool AnfToJsonDesc(const AnfNodePtrList &nodes, const DumpOption &dump_option, n
bool AnfToJsonDesc(const std::vector<AnfNodePtrList> &graphs, const DumpOption &dump_option, nlohmann::json *op_desc);
FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNodePtr> &inputs);
std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &prefix = "", const string &postfix = "");
std::vector<PrimitivePtr> GetFusibleOpList();
bool IsBasicFuseOp(const AnfNodePtr &node);
bool IsFusibleOp(const AnfNodePtr &node);
void ResetKernelInfo(const AnfNodePtr &node, KernelType kernel_type = KernelType::UNKNOWN_KERNEL_TYPE);
void ReplaceNewFuseCNodeForDependPrior(std::multimap<AnfNodePtr, std::pair<AnfNodePtr, AnfNodePtr>> *depend_prior,
const AnfNodePtr &new_fuse_cnode, const AnfNodePtrList &outputs);
std::string GetFormat(const AnfNodePtr &node);
TypePtr GetType(const AnfNodePtr &node);

View File

@ -25,7 +25,7 @@
#include "backend/optimizer/graph_kernel/add_atomic_clean.h"
#include "backend/optimizer/graph_kernel/add_stitch_atomic_clean_gpu.h"
#include "backend/optimizer/graph_kernel/arithmetic_simplify.h"
#include "backend/optimizer/graph_kernel/basic_ops_fusion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_cluster.h"
#include "backend/optimizer/graph_kernel/eliminate_redundant_output.h"
#include "backend/optimizer/graph_kernel/tensor_promotion.h"
#include "backend/optimizer/graph_kernel/graph_kernel_splitter.h"
@ -65,8 +65,8 @@ PassManagerPtr GraphKernelOptimizer::Cluster() const {
// Expand complex basic kernels to composite kernels
pm->AddPass(std::make_shared<GraphKernelExpander>());
// Fuse basic kernels and composite kernels
pm->AddPass(std::make_shared<BasicOpsFusion>());
// Cluster basic kernels and composite kernels
pm->AddPass(std::make_shared<GraphKernelCluster>());
// Eliminate the outputs without external user
pm->AddPass(std::make_shared<EliminateRedundantOutput>());