!2712 Revert "Make assign-node to be before jump-node, ensure child graph can get its inputs"

Merge pull request !2712 from zhoufeng/revert-assign-node-order
This commit is contained in:
mindspore-ci-bot 2020-06-29 19:53:39 +08:00 committed by Gitee
commit f8fa03d732
2 changed files with 9 additions and 34 deletions

View File

@ -33,21 +33,6 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
namespace mindspore { namespace mindspore {
namespace session { namespace session {
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
auto &nodes = parent_graph->execution_order();
for (auto &node : nodes) {
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) {
return node;
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) &&
(child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) {
return node;
}
}
MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
return nullptr;
}
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set, static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
const NotNull<std::set<KernelGraphPtr> *> memo) { const NotNull<std::set<KernelGraphPtr> *> memo) {
if (memo->find(kg.get()) != memo->end()) { if (memo->find(kg.get()) != memo->end()) {
@ -215,8 +200,7 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
if (target_graph_iter == graph_id_map.end()) { if (target_graph_iter == graph_id_map.end()) {
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found."; MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
} }
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg), InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
NOT_NULL(parameter));
} }
} }
} }
@ -449,8 +433,7 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
return {partial_cnode, branch_kg}; return {partial_cnode, branch_kg};
} }
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) { NotNull<AnfNodePtr> to) {
std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem}); std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem}); std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
@ -460,24 +443,18 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> fr
<< to_outputs.size() << "]"; << to_outputs.size() << "]";
} }
for (size_t i = 0; i < from_outputs.size(); i++) { for (size_t i = 0; i < from_outputs.size(); i++) {
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i])); InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
if (assign_node != nullptr) {
auto jump_node = GetJumpNode(from_graph, to_graph);
if (jump_node != nullptr) {
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
}
}
} }
} }
AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
NotNull<AnfNodePtr> to) { NotNull<AnfNodePtr> to) {
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) && if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) { AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
return nullptr; return;
} }
if (from.get() == to.get()) { if (from.get() == to.get()) {
return nullptr; return;
} }
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to " MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
<< to->DebugString(); << to->DebugString();
@ -489,7 +466,6 @@ AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg,
assign_node->set_abstract(to->abstract()); assign_node->set_abstract(to->abstract());
// append the assign at the end of from graph // append the assign at the end of from graph
InsertDependToGraph(kg, NOT_NULL(assign_node)); InsertDependToGraph(kg, NOT_NULL(assign_node));
return assign_node;
} }
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph, std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,

View File

@ -52,9 +52,8 @@ class AscendControlParser {
const CNodePtr &last_label); const CNodePtr &last_label);
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node); static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph, static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to); static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
// root graph order // root graph order
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode, static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,