!1522 Simplified graph order

Merge pull request !1522 from 何霞/graph_order
This commit is contained in:
mindspore-ci-bot 2020-05-28 18:53:19 +08:00 committed by Gitee
commit 65fe160845
2 changed files with 23 additions and 56 deletions

View File

@ -349,11 +349,10 @@ void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotN
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
std::set<KernelGraphPtr> memo;
(void)RecurseGraph(nullptr, nullptr, root_graph, NOT_NULL(&memo));
(void)RecurseGraph(root_graph, NOT_NULL(&memo));
}
std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto,
NotNull<KernelGraphPtr> graph,
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
NotNull<std::set<KernelGraphPtr> *> memo) {
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
auto print_vector = [&](std::vector<CNodePtr> vec) -> void {
@ -366,52 +365,38 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
return {};
}
memo->insert(graph.get());
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
graph->SetExecOrderByDefault();
const std::vector<CNodePtr> &cnodes = graph->execution_order();
std::map<uint32_t, CNodePtr> label_map;
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map;
std::tie(label_map, label_switch_map) = GetLabelNode(cnodes);
std::vector<CNodePtr> execution_order;
uint32_t child_order_index = 0;
for (auto &node : cnodes) {
execution_order.push_back(node);
if (node == graph->get_end_goto()) {
continue;
}
auto label_iter =
std::find_if(label_map.begin(), label_map.end(),
[node](const std::map<uint32_t, CNodePtr>::value_type iter) { return iter.second == node; });
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
if (label_iter == label_map.end() || !CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) {
if (!CheckLabelIndex(child_order_index, 0, node, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
auto child_graph = child_graph_order[label_iter->first];
auto child_graph = graph->child_graph_order()[child_order_index++];
if (child_graph == graph->parent_graph()) {
continue;
}
std::map<uint32_t, CNodePtr> child_label_map;
std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order());
auto child_execution_order =
RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo);
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
std::vector<uint32_t> label_list = label_switch_map.find(node)->second;
std::reverse(label_list.begin(), label_list.end());
for (size_t i = 0; i < label_list.size(); ++i) {
if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) {
std::vector<uint32_t> label_switch_list = GetLabelSwitchList(node);
for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) {
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
MS_LOG(EXCEPTION) << "Check label index fail";
}
auto child_graph = child_graph_order[label_iter->first + i];
auto child_graph = graph->child_graph_order()[child_order_index++];
if (child_graph == graph->parent_graph()) {
continue;
}
std::map<uint32_t, CNodePtr> child_label_map;
std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order());
auto child_execution_order =
RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo);
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
}
}
@ -421,6 +406,15 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
return execution_order;
}
std::vector<uint32_t> AscendControlParser::GetLabelSwitchList(const CNodePtr &node) {
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
}
auto primitive = AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
return GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
}
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
NotNull<KernelGraphPtr> graph) {
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
@ -458,31 +452,6 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
return true;
}
std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> AscendControlParser::GetLabelNode(
const std::vector<CNodePtr> &nodes) {
std::map<uint32_t, CNodePtr> label_map;
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map;
// record child graph
uint32_t index = 0;
for (auto &node : nodes) {
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
label_map[index++] = node;
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
}
auto primitive = AnfAlgo::GetCNodePrimitive(node);
MS_EXCEPTION_IF_NULL(primitive);
std::vector<uint32_t> label_list = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
label_switch_map.insert({node, label_list});
for (size_t i = 0; i < label_list.size(); ++i) {
label_map[index++] = node;
}
}
}
return {label_map, label_switch_map};
}
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
MS_LOG(INFO) << "graph id:" << kg->graph_id();
kg->SetExecOrderByDefault();

View File

@ -60,12 +60,10 @@ class AscendControlParser {
static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start);
// root graph order
static std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> GetLabelNode(
const std::vector<CNodePtr> &nodes);
static std::vector<uint32_t> GetLabelSwitchList(const CNodePtr &node);
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
NotNull<KernelGraphPtr> graph);
static std::vector<CNodePtr> RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto,
NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
};
} // namespace session
} // namespace mindspore