!25614 [GraphKernel] Enable parallel fusion in Ascend and enhance parallel feature.

Merge pull request !25614 from TronZhang/parallel_support_in_ascend
i-robot 2021-11-12 09:44:03 +00:00 committed by Gitee
commit 5211733add
8 changed files with 137 additions and 13 deletions

View File

@@ -290,6 +290,12 @@ def block_parallel_estimate(graphs):
return ParalGain(fusion_type, max_weight, sum_weight - max_weight, blocks, type_info)
def parallel_estimate(graphs):
def parallel_estimate(graphs, target):
"""Estimate parallel gain"""
if target == "aicore":
fusion_type = "block_fusion"
type_info = None
fake_estimate = 1000
fake_blocks = [1 for g in graphs]
return ParalGain(fusion_type, fake_estimate, fake_estimate, fake_blocks, type_info)
return block_parallel_estimate(graphs)
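
For reference, a minimal usage sketch of the new signature (a hypothetical call; the ParalGain field names are taken from how estimate_ops reads the result below):

# Hedged sketch: on Ascend ("aicore") the cost model is bypassed with a fixed
# fake estimate, so every graph gets block id 1; other targets still run
# block_parallel_estimate.
gain = parallel_estimate(graphs, "aicore")
assert gain.fusion_type == "block_fusion"
assert gain.block_assign == [1] * len(graphs)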

View File

@@ -20,15 +20,21 @@ from mindspore import log as logger
from . import model
def estimate_ops(json_str: str):
def estimate_ops(json_str):
"""Call cost model to estimate ops."""
try:
json_obj = json.loads(json_str)
graph_descs = json_obj["graph_desc"]
graphs = []
target = None
for gd in graph_descs:
if target is None:
target = gd['process']
elif target != gd['process']:
logger.error("Parallel fusion does not support multi-target({} and {})".format(target, gd['process']))
return None
graphs.append(model.load_composite(gd).graph)
estimation = model.parallel_estimate(graphs)
estimation = model.parallel_estimate(graphs, target)
res = (estimation.block_assign, estimation.gain,
estimation.fusion_type, estimation.type_info)
return res
@@ -37,12 +43,13 @@ def estimate_ops(json_str: str):
return None
def estimate_calulation_amount(json_str: str):
def estimate_calulation_amount(json_str):
"""Call cost model to estimate calculation amount of op."""
try:
graph_desc = json.loads(json_str)
target = graph_desc['process']
comp = model.load_composite(graph_desc)
estimation = model.parallel_estimate([comp.graph])
estimation = model.parallel_estimate([comp.graph], target)
return estimation.bottleneck
except jd.JSONDecodeError:
logger.error(traceback.format_exc())
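
For reference, a minimal sketch of the input shape these helpers expect (only the keys read above are shown; everything else in a real graph description is elided):

import json

# Hedged sketch: every sub-graph description must report the same 'process'
# (backend), e.g. "aicore" for Ascend, or estimate_ops bails out with None.
descs = [{"process": "aicore"}, {"process": "aicore"}]
json_str = json.dumps({"graph_desc": descs})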

View File

@@ -56,6 +56,7 @@ using opt::GetitemTuple;
using opt::GraphOptimizer;
namespace {
auto constexpr PARALLEL_OPS_LIMIT = 7;
inline unsigned int GetPassLevelByFlag(bool flag) { return flag ? OptLevel_1 : OptLevel_MAX; }
} // namespace
@@ -175,8 +176,14 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
PassManagerPtr GraphKernelOptimizer::Combine() const {
auto pm = std::make_shared<GraphKernelPassManager>(5, "combine");
// Enable parallel fusion for gpu and ascend devices
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
auto level = GetPassLevelByFlag(GraphKernelFlags::GetInstance().enable_parallel_fusion);
pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7)), level, is_gpu);
// An atomic-add GraphKernel node may be linked directly to an UpdateState; it should be spread before parallel fusion!
pm->AddPass(std::make_shared<SpreadUpdateState>(), level);
pm->AddPass(std::make_shared<ParallelOpFusion>(target, ParallelConfig(PARALLEL_OPS_LIMIT)), level,
is_gpu || is_ascend);
return pm;
}

View File

@@ -112,8 +112,8 @@ FusionInfoPtr ParallelCostModel::ProcessFusionInfo(const py::object &fusion_type
}
ParallelCostModelPtr ParellelCostModelWarehouse::GetParallelCostModel(const std::string &target) const {
if (target != kGPUDevice) {
MS_LOG(EXCEPTION) << "Parallel cost model only support " << kGPUDevice << " now.";
if (target != kGPUDevice && target != kAscendDevice) {
MS_LOG(EXCEPTION) << "Parallel cost model do not support " << target << " now.";
}
return cost_model_;
}

View File

@@ -19,9 +19,12 @@
#include <algorithm>
#include <list>
#include <queue>
#include <unordered_map>
#include <utility>
#include "utils/context/graph_kernel_flags.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
#include "backend/kernel_compiler/common_utils.h"
#include "frontend/operator/ops.h"
#include "ir/func_graph_cloner.h"
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
@@ -269,6 +272,32 @@ bool WhiteOpsFilter(const AnfNodePtr &node) {
return session::AnfRuntimeAlgorithm::IsGraphKernel(node) || IsOneOf(node, whiteable_ops);
}
bool Unfavorable(const AnfNodePtr &node) {
// Parallel fusion cannot work with stitching for now.
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto input = cnode->input(kAnfPrimitiveIndex);
if (!IsValueNode<FuncGraph>(input)) {
return AnfAlgo::HasNodeAttr(kAttrStitch, cnode);
}
auto func_graph = GetValueNode<FuncGraphPtr>(input);
MS_EXCEPTION_IF_NULL(func_graph);
AnfNodePtrList sub_nodes;
kernel::GetValidKernelNodes(func_graph, &sub_nodes);
for (auto sub_node : sub_nodes) {
auto sub_cnode = sub_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(sub_cnode);
if (AnfAlgo::HasNodeAttr(kAttrStitch, sub_cnode)) {
return true;
}
}
return false;
}
bool Parallelizable(const AnfNodePtr &node) { return WhiteOpsFilter(node) && !Unfavorable(node); }
std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes,
const std::function<bool(const AnfNodePtr &)> &filter_func,
const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
@@ -320,7 +349,7 @@ void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
if (auto iter = node_rels.find(node); iter != node_rels.end()) {
const auto &pre_nodes = get_related_nodes(iter->second);
AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end());
groups->push_back(SearchFromNodes(related_nodes, WhiteOpsFilter, node_rels, is_backward, seen));
groups->push_back(SearchFromNodes(related_nodes, Parallelizable, node_rels, is_backward, seen));
}
}
@@ -337,7 +366,7 @@ void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes,
const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
groups->push_back(SearchFromNodes(ud_nodes, WhiteOpsFilter, node_rels, is_backward, seen));
groups->push_back(SearchFromNodes(ud_nodes, Parallelizable, node_rels, is_backward, seen));
// Erase empty groups.
for (auto iter = groups->begin(); iter != groups->end();) {
@@ -358,8 +387,9 @@ std::string DumpNode(const AnfNodePtr &node) {
return buf.str();
}
void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups) {
MS_LOG(INFO) << "There are " << groups.size() << " parallel groups, their detail is: ";
void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups, const std::string &title = "") {
MS_LOG(INFO) << "[" << title << "]"
<< "There are " << groups.size() << " parallel groups, their detail is: ";
int i = 0;
for (const auto &group : groups) {
std::stringstream buf;
@@ -465,7 +495,7 @@ std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups(
SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen);
SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen);
DumpParallelGroups(groups);
DumpParallelGroups(groups, "Dependency Analyze");
return groups;
}
@@ -725,8 +755,63 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
return changed;
}
std::set<AnfNodePtr> CollectCapturedNodes(const std::vector<ParallelInfo> &infos) {
std::set<AnfNodePtr> captured;
std::for_each(infos.cbegin(), infos.cend(),
[&captured](const ParallelInfo &info) { captured.insert(info.nodes().begin(), info.nodes().end()); });
return captured;
}
std::vector<std::vector<AnfNodePtrList>> GetParallelGroupsByBfs(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels,
const std::set<AnfNodePtr> &exclude) {
std::vector<std::vector<AnfNodePtrList>> groups;
// BFS
std::queue<AnfNodePtr> node_que;
std::unordered_map<AnfNodePtr, int> outdegrees;
for (const auto &[node, ref] : node_rels) {
outdegrees[node] = ref.nexts.size();
if (outdegrees[node] == 0) {
node_que.push(node);
}
}
int total_node_num = node_rels.size();
while (!node_que.empty()) {
std::vector<AnfNodePtrList> group;
int node_size = node_que.size();
while (node_size--) {
auto node = node_que.front();
node_que.pop();
if (exclude.count(node) == 0 && Parallelizable(node)) {
group.push_back({node});
}
--total_node_num;
auto iter = node_rels.find(node);
if (iter == node_rels.end()) {
MS_LOG(EXCEPTION) << "Internal error in node relationship!";
}
for (const auto &pre : iter->second.pres) {
if (--outdegrees[pre] == 0) {
node_que.push(pre);
}
}
}
if (!group.empty()) {
groups.push_back(group);
}
}
if (total_node_num > 0) {
MS_LOG(EXCEPTION) << "There is circle in analyze graph!";
}
DumpParallelGroups(groups, "BFS");
return groups;
}
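
The function above is a reverse topological sweep: nodes whose outputs are all consumed (out-degree zero) seed the queue, each BFS level becomes one candidate parallel group, and already-captured or non-parallelizable nodes are skipped. A minimal Python mirror of the same idea, with hypothetical names and node_rels as a dict of {"pres": set, "nexts": set}:

from collections import deque

def parallel_groups_by_bfs(node_rels, exclude, parallelizable):
    """Level-order sweep from sinks toward sources; each level forms one group."""
    outdegree = {n: len(r["nexts"]) for n, r in node_rels.items()}
    queue = deque(n for n, d in outdegree.items() if d == 0)
    remaining = len(node_rels)
    groups = []
    while queue:
        group = []
        for _ in range(len(queue)):
            node = queue.popleft()
            remaining -= 1
            if node not in exclude and parallelizable(node):
                group.append([node])  # one single-node stream per candidate
            for pre in node_rels[node]["pres"]:
                outdegree[pre] -= 1
                if outdegree[pre] == 0:
                    queue.append(pre)
        if group:
            groups.append(group)
    if remaining:  # some nodes never reached out-degree zero
        raise RuntimeError("cycle in the analyzed graph")
    return groups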
bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
parallel_level_ = GraphKernelFlags::GetInstance().parallel_ops_level;
(void)std::make_shared<ShrinkUpdateState>()->Run(graph);
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);
@@ -741,6 +826,15 @@ bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
auto groups = SearchParallelGroups(node_rels);
auto parallel_infos = SearchFusableParallelCNodes(groups);
// Search the remaining nodes in BFS order.
if (parallel_level_ > 0) {
auto excluded_nodes = CollectCapturedNodes(parallel_infos);
auto groups_bfs = GetParallelGroupsByBfs(node_rels, excluded_nodes);
auto bfs_parallel_infos = SearchFusableParallelCNodes(groups_bfs);
parallel_infos.insert(parallel_infos.end(), bfs_parallel_infos.begin(), bfs_parallel_infos.end());
}
// Create core-fused subgraphs and change the origin graph.
bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
(void)std::make_shared<SpreadUpdateState>()->Run(graph);

View File

@@ -120,6 +120,7 @@ class ParallelOpFusion : public opt::Pass {
ParallelCostModelPtr cost_model_ptr_;
std::set<AnfNodePtr> virtual_noout_nodes_;
std::set<AnfNodePtr> ignore_noin_nodes_;
unsigned int parallel_level_{0};
};
using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>;
} // namespace mindspore::graphkernel

View File

@@ -213,6 +213,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
// Integer flags
reg.AddFlag("online_tuning", &online_tuning);
reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_ascend ? OpLevel_0 : OpLevel_MAX);
reg.AddFlag("parallel_ops_level", &parallel_ops_level);
// String flags
reg.AddFlag("repository_path", &repository_path);
@@ -242,6 +243,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {
json["opt_level"] = opt_level;
json["fusion_ops_level"] = fusion_ops_level;
json["parallel_ops_level"] = parallel_ops_level;
json["online_tuning"] = online_tuning;
json["repository_path"] = repository_path;

View File

@@ -83,6 +83,13 @@ class GraphKernelFlags {
*/
bool enable_parallel_fusion{false};
/**
* Parallelize AKG operators by level.
* 0: Parallelize operators using local data-relation analysis, with less memory influence.
* 1: Parallelize operators using global analysis, with more memory influence.
*/
unsigned int parallel_ops_level{OpLevel_0};
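
Assuming the new flag is consumed through the usual graph_kernel_flags context string like the other flags registered above, enabling it from Python could look like this (a hedged sketch, not part of the diff):

from mindspore import context

# Hypothetical usage: turn on parallel fusion and the global (BFS) analysis level.
context.set_context(device_target="Ascend", enable_graph_kernel=True,
                    graph_kernel_flags="--enable_parallel_fusion=true --parallel_ops_level=1")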
/**
* Enable low precision in data transfer between graph kernels and in computation
* within graph kernels.