forked from mindspore-Ecosystem/mindspore
!25614 [GraphKernel] Enable parallel fusion in Ascend and enhance parallel feature.
Merge pull request !25614 from TronZhang/parallel_support_in_ascend
This commit is contained in:
commit
5211733add
|
@ -290,6 +290,12 @@ def block_parallel_estimate(graphs):
|
|||
return ParalGain(fusion_type, max_weight, sum_weight - max_weight, blocks, type_info)
|
||||
|
||||
|
||||
def parallel_estimate(graphs):
|
||||
def parallel_estimate(graphs, target):
    """Estimate the gain of running `graphs` in parallel on `target`.

    For the "aicore" target a fixed placeholder estimation is returned, with
    one block assigned to each graph. Every other target is delegated to
    block_parallel_estimate.
    """
    if target != "aicore":
        return block_parallel_estimate(graphs)
    placeholder_cost = 1000
    blocks_per_graph = [1 for _ in graphs]
    return ParalGain("block_fusion", placeholder_cost, placeholder_cost, blocks_per_graph, None)
|
||||
|
|
|
@ -20,15 +20,21 @@ from mindspore import log as logger
|
|||
from . import model
|
||||
|
||||
|
||||
def estimate_ops(json_str: str):
|
||||
def estimate_ops(json_str):
|
||||
"""Call cost model to estimate ops."""
|
||||
try:
|
||||
json_obj = json.loads(json_str)
|
||||
graph_descs = json_obj["graph_desc"]
|
||||
graphs = []
|
||||
target = None
|
||||
for gd in graph_descs:
|
||||
if target is None:
|
||||
target = gd['process']
|
||||
elif target != gd['process']:
|
||||
logger.error("Parallel fusion does not support multi-target({} and {})".format(target, gd['process']))
|
||||
return None
|
||||
graphs.append(model.load_composite(gd).graph)
|
||||
estimation = model.parallel_estimate(graphs)
|
||||
estimation = model.parallel_estimate(graphs, target)
|
||||
res = (estimation.block_assign, estimation.gain,
|
||||
estimation.fusion_type, estimation.type_info)
|
||||
return res
|
||||
|
@ -37,12 +43,13 @@ def estimate_ops(json_str: str):
|
|||
return None
|
||||
|
||||
|
||||
def estimate_calulation_amount(json_str: str):
|
||||
def estimate_calulation_amount(json_str):
|
||||
"""Call cost model to estimate calculation amount of op."""
|
||||
try:
|
||||
graph_desc = json.loads(json_str)
|
||||
target = graph_desc['process']
|
||||
comp = model.load_composite(graph_desc)
|
||||
estimation = model.parallel_estimate([comp.graph])
|
||||
estimation = model.parallel_estimate([comp.graph], target)
|
||||
return estimation.bottleneck
|
||||
except jd.JSONDecodeError:
|
||||
logger.error(traceback.format_exc())
|
||||
|
|
|
@ -56,6 +56,7 @@ using opt::GetitemTuple;
|
|||
using opt::GraphOptimizer;
|
||||
|
||||
namespace {
|
||||
auto constexpr PARALLEL_OPS_LIMIT = 7;
|
||||
inline unsigned int GetPassLevelByFlag(bool flag) { return flag ? OptLevel_1 : OptLevel_MAX; }
|
||||
} // namespace
|
||||
|
||||
|
@ -175,8 +176,14 @@ PassManagerPtr GraphKernelOptimizer::HighLevelOpt2() const {
|
|||
PassManagerPtr GraphKernelOptimizer::Combine() const {
|
||||
auto pm = std::make_shared<GraphKernelPassManager>(5, "combine");
|
||||
// Enable parallel fusion for gpu and ascend devices
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
auto target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
|
||||
auto level = GetPassLevelByFlag(GraphKernelFlags::GetInstance().enable_parallel_fusion);
|
||||
pm->AddPass(std::make_shared<ParallelOpFusion>(kGPUDevice, ParallelConfig(7)), level, is_gpu);
|
||||
// Atomic-add GraphKernel node may be linked directly to UpdateState, it should be spread before parallel fusion!
|
||||
pm->AddPass(std::make_shared<SpreadUpdateState>(), level);
|
||||
pm->AddPass(std::make_shared<ParallelOpFusion>(target, ParallelConfig(PARALLEL_OPS_LIMIT)), level,
|
||||
is_gpu || is_ascend);
|
||||
|
||||
return pm;
|
||||
}
|
||||
|
|
|
@ -112,8 +112,8 @@ FusionInfoPtr ParallelCostModel::ProcessFusionInfo(const py::object &fusion_type
|
|||
}
|
||||
|
||||
ParallelCostModelPtr ParellelCostModelWarehouse::GetParallelCostModel(const std::string &target) const {
|
||||
if (target != kGPUDevice) {
|
||||
MS_LOG(EXCEPTION) << "Parallel cost model only support " << kGPUDevice << " now.";
|
||||
if (target != kGPUDevice && target != kAscendDevice) {
|
||||
MS_LOG(EXCEPTION) << "Parallel cost model do not support " << target << " now.";
|
||||
}
|
||||
return cost_model_;
|
||||
}
|
||||
|
|
|
@ -19,9 +19,12 @@
|
|||
#include <algorithm>
|
||||
#include <list>
|
||||
#include <queue>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include "utils/context/graph_kernel_flags.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
|
||||
#include "backend/kernel_compiler/common_utils.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "backend/optimizer/graph_kernel/update_state_formatter.h"
|
||||
|
@ -269,6 +272,32 @@ bool WhiteOpsFilter(const AnfNodePtr &node) {
|
|||
return session::AnfRuntimeAlgorithm::IsGraphKernel(node) || IsOneOf(node, whiteable_ops);
|
||||
}
|
||||
|
||||
bool Unfavorable(const AnfNodePtr &node) {
|
||||
// Parallel cannot work with stitching for now.
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto input = cnode->input(kAnfPrimitiveIndex);
|
||||
if (!IsValueNode<FuncGraph>(input)) {
|
||||
return AnfAlgo::HasNodeAttr(kAttrStitch, cnode);
|
||||
}
|
||||
|
||||
auto func_graph = GetValueNode<FuncGraphPtr>(input);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
AnfNodePtrList sub_nodes;
|
||||
kernel::GetValidKernelNodes(func_graph, &sub_nodes);
|
||||
for (auto sub_node : sub_nodes) {
|
||||
auto sub_cnode = sub_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(sub_cnode);
|
||||
if (AnfAlgo::HasNodeAttr(kAttrStitch, sub_cnode)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// A node may join a parallel group only if it passes WhiteOpsFilter and is not Unfavorable.
bool Parallelizable(const AnfNodePtr &node) {
  if (!WhiteOpsFilter(node)) {
    return false;
  }
  return !Unfavorable(node);
}
|
||||
|
||||
std::vector<AnfNodePtrList> SearchFromNodes(const AnfNodePtrList &nodes,
|
||||
const std::function<bool(const AnfNodePtr &)> &filter_func,
|
||||
const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
|
||||
|
@ -320,7 +349,7 @@ void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
|
|||
if (auto iter = node_rels.find(node); iter != node_rels.end()) {
|
||||
const auto &pre_nodes = get_related_nodes(iter->second);
|
||||
AnfNodePtrList related_nodes(pre_nodes.begin(), pre_nodes.end());
|
||||
groups->push_back(SearchFromNodes(related_nodes, WhiteOpsFilter, node_rels, is_backward, seen));
|
||||
groups->push_back(SearchFromNodes(related_nodes, Parallelizable, node_rels, is_backward, seen));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -337,7 +366,7 @@ void SearchStreamFromMultiRelationNode(const AnfNodePtrList &multi_nodes,
|
|||
void SearchStreamFromUnidirectionalNode(const AnfNodePtrList &ud_nodes,
|
||||
const OrderedMap<AnfNodePtr, NodeRelation> &node_rels, bool is_backward,
|
||||
std::vector<std::vector<AnfNodePtrList>> *groups, std::set<AnfNodePtr> *seen) {
|
||||
groups->push_back(SearchFromNodes(ud_nodes, WhiteOpsFilter, node_rels, is_backward, seen));
|
||||
groups->push_back(SearchFromNodes(ud_nodes, Parallelizable, node_rels, is_backward, seen));
|
||||
|
||||
// Erase empty groups.
|
||||
for (auto iter = groups->begin(); iter != groups->end();) {
|
||||
|
@ -358,8 +387,9 @@ std::string DumpNode(const AnfNodePtr &node) {
|
|||
return buf.str();
|
||||
}
|
||||
|
||||
void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups) {
|
||||
MS_LOG(INFO) << "There are " << groups.size() << " parallel groups, their detail is: ";
|
||||
void DumpParallelGroups(const std::vector<std::vector<AnfNodePtrList>> &groups, const std::string &title = "") {
|
||||
MS_LOG(INFO) << "[" << title << "]"
|
||||
<< "There are " << groups.size() << " parallel groups, their detail is: ";
|
||||
int i = 0;
|
||||
for (const auto group : groups) {
|
||||
std::stringstream buf;
|
||||
|
@ -465,7 +495,7 @@ std::vector<std::vector<AnfNodePtrList>> ParallelOpFusion::SearchParallelGroups(
|
|||
SearchStreamFromMultiRelationNode(mul_outs_nodes, node_rels, false, &groups, &seen);
|
||||
SearchStreamFromUnidirectionalNode(no_in_nodes, node_rels, false, &groups, &seen);
|
||||
|
||||
DumpParallelGroups(groups);
|
||||
DumpParallelGroups(groups, "Dependency Analyze");
|
||||
return groups;
|
||||
}
|
||||
|
||||
|
@ -725,8 +755,63 @@ bool ParallelOpFusion::CreateParallelOpSubGraphs(const std::vector<ParallelInfo>
|
|||
return changed;
|
||||
}
|
||||
|
||||
std::set<AnfNodePtr> CollectCapturedNodes(const std::vector<ParallelInfo> &infos) {
|
||||
std::set<AnfNodePtr> captured;
|
||||
std::for_each(infos.cbegin(), infos.cend(),
|
||||
[&captured](const ParallelInfo &info) { captured.insert(info.nodes().begin(), info.nodes().end()); });
|
||||
return captured;
|
||||
}
|
||||
|
||||
std::vector<std::vector<AnfNodePtrList>> GetParallelGroupsByBfs(const OrderedMap<AnfNodePtr, NodeRelation> &node_rels,
|
||||
const std::set<AnfNodePtr> &exclude) {
|
||||
std::vector<std::vector<AnfNodePtrList>> groups;
|
||||
// BFS
|
||||
std::queue<AnfNodePtr> node_que;
|
||||
std::unordered_map<AnfNodePtr, int> outdegrees;
|
||||
for (const auto &[node, ref] : node_rels) {
|
||||
outdegrees[node] = ref.nexts.size();
|
||||
if (outdegrees[node] == 0) {
|
||||
node_que.push(node);
|
||||
}
|
||||
}
|
||||
|
||||
int total_node_num = node_rels.size();
|
||||
while (!node_que.empty()) {
|
||||
std::vector<AnfNodePtrList> group;
|
||||
int node_size = node_que.size();
|
||||
while (node_size--) {
|
||||
auto node = node_que.front();
|
||||
node_que.pop();
|
||||
if (exclude.count(node) == 0 && Parallelizable(node)) {
|
||||
group.push_back({node});
|
||||
}
|
||||
--total_node_num;
|
||||
auto iter = node_rels.find(node);
|
||||
if (iter == node_rels.end()) {
|
||||
MS_LOG(EXCEPTION) << "Internal error in node relationship!";
|
||||
}
|
||||
for (const auto &pre : iter->second.pres) {
|
||||
if (--outdegrees[pre] == 0) {
|
||||
node_que.push(pre);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!group.empty()) {
|
||||
groups.push_back(group);
|
||||
}
|
||||
}
|
||||
|
||||
if (total_node_num > 0) {
|
||||
MS_LOG(EXCEPTION) << "There is circle in analyze graph!";
|
||||
}
|
||||
DumpParallelGroups(groups, "BFS");
|
||||
return groups;
|
||||
}
|
||||
|
||||
bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
parallel_level_ = GraphKernelFlags::GetInstance().parallel_ops_level;
|
||||
|
||||
(void)std::make_shared<ShrinkUpdateState>()->Run(graph);
|
||||
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
@ -741,6 +826,15 @@ bool ParallelOpFusion::Run(const FuncGraphPtr &graph) {
|
|||
auto groups = SearchParallelGroups(node_rels);
|
||||
auto parallel_infos = SearchFusableParallelCNodes(groups);
|
||||
|
||||
// Search the remaining (not yet captured) nodes by BFS.
|
||||
if (parallel_level_ > 0) {
|
||||
auto exclued_nodes = CollectCapturedNodes(parallel_infos);
|
||||
auto groups_bfs = GetParallelGroupsByBfs(node_rels, exclued_nodes);
|
||||
auto bfs_parallel_infos = SearchFusableParallelCNodes(groups_bfs);
|
||||
|
||||
parallel_infos.insert(parallel_infos.end(), bfs_parallel_infos.begin(), bfs_parallel_infos.end());
|
||||
}
|
||||
|
||||
// Create core-fuse subgraph and change origin graph.
|
||||
bool changed = CreateParallelOpSubGraphs(parallel_infos, kernel_graph);
|
||||
(void)std::make_shared<SpreadUpdateState>()->Run(graph);
|
||||
|
|
|
@ -120,6 +120,7 @@ class ParallelOpFusion : public opt::Pass {
|
|||
ParallelCostModelPtr cost_model_ptr_;
|
||||
std::set<AnfNodePtr> virtual_noout_nodes_;
|
||||
std::set<AnfNodePtr> ignore_noin_nodes_;
|
||||
unsigned int parallel_level_{0};
|
||||
};
|
||||
using ParallelOpFusionPtr = std::shared_ptr<ParallelOpFusion>;
|
||||
} // namespace mindspore::graphkernel
|
||||
|
|
|
@ -213,6 +213,7 @@ void GraphKernelFlags::RegisterFlags(std::map<std::string, std::string> *flag_ma
|
|||
// Integer flags
|
||||
reg.AddFlag("online_tuning", &online_tuning);
|
||||
reg.AddFlag("fusion_ops_level", &fusion_ops_level, is_ascend ? OpLevel_0 : OpLevel_MAX);
|
||||
reg.AddFlag("parallel_ops_level", ¶llel_ops_level);
|
||||
|
||||
// String flags
|
||||
reg.AddFlag("repository_path", &repository_path);
|
||||
|
@ -242,6 +243,7 @@ std::string GraphKernelFlags::DumpAllFlags() const {
|
|||
|
||||
json["opt_level"] = opt_level;
|
||||
json["fusion_ops_level"] = fusion_ops_level;
|
||||
json["parallel_ops_level"] = parallel_ops_level;
|
||||
json["online_tuning"] = online_tuning;
|
||||
|
||||
json["repository_path"] = repository_path;
|
||||
|
|
|
@ -83,6 +83,13 @@ class GraphKernelFlags {
|
|||
*/
|
||||
bool enable_parallel_fusion{false};
|
||||
|
||||
/**
|
||||
* Parallelize AKG operators by level.
|
||||
* 0: Parallelize operators by local data relation analysis, with less memory influence.
|
||||
* 1: Parallelize operators by global analysis, with more memory influence.
|
||||
*/
|
||||
unsigned int parallel_ops_level{OpLevel_0};
|
||||
|
||||
/**
|
||||
* Enable low precision in data transferring between graph kernel and computing in graph kernel
|
||||
* in graph kernel.
|
||||
|
|
Loading…
Reference in New Issue