forked from mindspore-Ecosystem/mindspore
!1522 Simplified graph order
Merge pull request !1522 from 何霞/graph_order
This commit is contained in:
commit
65fe160845
|
@ -349,11 +349,10 @@ void AscendControlParser::LinkArgsToParam(NotNull<KernelGraphPtr> to_graph, NotN
|
|||
|
||||
void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
|
||||
std::set<KernelGraphPtr> memo;
|
||||
(void)RecurseGraph(nullptr, nullptr, root_graph, NOT_NULL(&memo));
|
||||
(void)RecurseGraph(root_graph, NOT_NULL(&memo));
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto,
|
||||
NotNull<KernelGraphPtr> graph,
|
||||
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
|
||||
NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
MS_LOG(INFO) << "graph:" << graph->graph_id() << " start";
|
||||
auto print_vector = [&](std::vector<CNodePtr> vec) -> void {
|
||||
|
@ -366,52 +365,38 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
|
|||
return {};
|
||||
}
|
||||
memo->insert(graph.get());
|
||||
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
|
||||
graph->SetExecOrderByDefault();
|
||||
|
||||
const std::vector<CNodePtr> &cnodes = graph->execution_order();
|
||||
std::map<uint32_t, CNodePtr> label_map;
|
||||
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map;
|
||||
std::tie(label_map, label_switch_map) = GetLabelNode(cnodes);
|
||||
|
||||
std::vector<CNodePtr> execution_order;
|
||||
uint32_t child_order_index = 0;
|
||||
|
||||
for (auto &node : cnodes) {
|
||||
execution_order.push_back(node);
|
||||
if (node == graph->get_end_goto()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto label_iter =
|
||||
std::find_if(label_map.begin(), label_map.end(),
|
||||
[node](const std::map<uint32_t, CNodePtr>::value_type iter) { return iter.second == node; });
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
|
||||
if (label_iter == label_map.end() || !CheckLabelIndex(label_iter->first, 0, label_iter->second, graph)) {
|
||||
if (!CheckLabelIndex(child_order_index, 0, node, graph)) {
|
||||
MS_LOG(EXCEPTION) << "Check label index fail";
|
||||
}
|
||||
auto child_graph = child_graph_order[label_iter->first];
|
||||
auto child_graph = graph->child_graph_order()[child_order_index++];
|
||||
if (child_graph == graph->parent_graph()) {
|
||||
continue;
|
||||
}
|
||||
std::map<uint32_t, CNodePtr> child_label_map;
|
||||
std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order());
|
||||
auto child_execution_order =
|
||||
RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo);
|
||||
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
|
||||
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
|
||||
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
|
||||
std::vector<uint32_t> label_list = label_switch_map.find(node)->second;
|
||||
std::reverse(label_list.begin(), label_list.end());
|
||||
for (size_t i = 0; i < label_list.size(); ++i) {
|
||||
if (!CheckLabelIndex(label_iter->first + i, label_list[i], label_iter->second, graph)) {
|
||||
std::vector<uint32_t> label_switch_list = GetLabelSwitchList(node);
|
||||
for (auto iter = label_switch_list.rbegin(); iter != label_switch_list.rend(); ++iter) {
|
||||
if (!CheckLabelIndex(child_order_index, *iter, node, graph)) {
|
||||
MS_LOG(EXCEPTION) << "Check label index fail";
|
||||
}
|
||||
auto child_graph = child_graph_order[label_iter->first + i];
|
||||
auto child_graph = graph->child_graph_order()[child_order_index++];
|
||||
if (child_graph == graph->parent_graph()) {
|
||||
continue;
|
||||
}
|
||||
std::map<uint32_t, CNodePtr> child_label_map;
|
||||
std::tie(child_label_map, std::ignore) = GetLabelNode(child_graph->execution_order());
|
||||
auto child_execution_order =
|
||||
RecurseGraph(child_label_map.begin()->second, child_label_map.rbegin()->second, NOT_NULL(child_graph), memo);
|
||||
auto child_execution_order = RecurseGraph(NOT_NULL(child_graph), memo);
|
||||
execution_order.insert(execution_order.end(), child_execution_order.begin(), child_execution_order.end());
|
||||
}
|
||||
}
|
||||
|
@ -421,6 +406,15 @@ std::vector<CNodePtr> AscendControlParser::RecurseGraph(const CNodePtr &cur_labe
|
|||
return execution_order;
|
||||
}
|
||||
|
||||
std::vector<uint32_t> AscendControlParser::GetLabelSwitchList(const CNodePtr &node) {
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
|
||||
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
|
||||
}
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
return GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
|
||||
}
|
||||
|
||||
bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cur_label,
|
||||
NotNull<KernelGraphPtr> graph) {
|
||||
const std::vector<std::shared_ptr<KernelGraph>> &child_graph_order = graph->child_graph_order();
|
||||
|
@ -458,31 +452,6 @@ bool AscendControlParser::CheckLabelIndex(uint32_t order_index, uint32_t label_i
|
|||
return true;
|
||||
}
|
||||
|
||||
std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> AscendControlParser::GetLabelNode(
|
||||
const std::vector<CNodePtr> &nodes) {
|
||||
std::map<uint32_t, CNodePtr> label_map;
|
||||
std::map<CNodePtr, std::vector<uint32_t>> label_switch_map;
|
||||
// record child graph
|
||||
uint32_t index = 0;
|
||||
for (auto &node : nodes) {
|
||||
if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelGoto)) {
|
||||
label_map[index++] = node;
|
||||
} else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimLabelSwitch)) {
|
||||
if (!AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, node)) {
|
||||
MS_LOG(EXCEPTION) << "LabelSwitchKernel has no attr label_switch_list";
|
||||
}
|
||||
auto primitive = AnfAlgo::GetCNodePrimitive(node);
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
std::vector<uint32_t> label_list = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList));
|
||||
label_switch_map.insert({node, label_list});
|
||||
for (size_t i = 0; i < label_list.size(); ++i) {
|
||||
label_map[index++] = node;
|
||||
}
|
||||
}
|
||||
}
|
||||
return {label_map, label_switch_map};
|
||||
}
|
||||
|
||||
void AscendControlParser::UpdateChildGraphOrder(NotNull<KernelGraphPtr> kg) {
|
||||
MS_LOG(INFO) << "graph id:" << kg->graph_id();
|
||||
kg->SetExecOrderByDefault();
|
||||
|
|
|
@ -60,12 +60,10 @@ class AscendControlParser {
|
|||
static CNodePtr GetNextRealKernel(const std::vector<CNodePtr> &list, size_t start);
|
||||
|
||||
// root graph order
|
||||
static std::tuple<std::map<uint32_t, CNodePtr>, std::map<CNodePtr, std::vector<uint32_t>>> GetLabelNode(
|
||||
const std::vector<CNodePtr> &nodes);
|
||||
static std::vector<uint32_t> GetLabelSwitchList(const CNodePtr &node);
|
||||
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
|
||||
NotNull<KernelGraphPtr> graph);
|
||||
static std::vector<CNodePtr> RecurseGraph(const CNodePtr &cur_label_goto, const CNodePtr &end_label_goto,
|
||||
NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
static std::vector<CNodePtr> RecurseGraph(NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo);
|
||||
};
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
Loading…
Reference in New Issue