forked from mindspore-Ecosystem/mindspore
!2596 Make assign-node to be before jump-node, ensure child graph can get its input
Merge pull request !2596 from zhoufeng/assign-order-before-jump-r0.5
This commit is contained in:
commit
4f377f2ab4
|
@ -33,6 +33,21 @@ static constexpr size_t kCNodeSwitchLayerLength = 3;
|
|||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
static CNodePtr GetJumpNode(NotNull<KernelGraphPtr> parent_graph, NotNull<KernelGraphPtr> child_graph) {
|
||||
auto &nodes = parent_graph->execution_order();
|
||||
for (auto &node : nodes) {
|
||||
if (IsPrimitiveCNode(node, prim::kPrimLabelGoto) && child_graph->get_start_label() == node->input(kCNodeCallArg)) {
|
||||
return node;
|
||||
} else if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch) &&
|
||||
(child_graph->get_start_label() == node->input(kCNodeSwitchFalse) ||
|
||||
child_graph->get_start_label() == node->input(kCNodeSwitchTrue))) {
|
||||
return node;
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "Cannot find jump node from " << parent_graph->ToString() << " to " << child_graph->ToString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
static void InitUnionFindSet(NotNull<KernelGraphPtr> kg, const NotNull<UnionFindSet<AnfNodePtr> *> union_find_set,
|
||||
const NotNull<std::set<KernelGraphPtr> *> memo) {
|
||||
if (memo->find(kg.get()) != memo->end()) {
|
||||
|
@ -200,7 +215,8 @@ void AscendControlParser::ChildGraphDataAssign(const std::map<uint32_t, KernelGr
|
|||
if (target_graph_iter == graph_id_map.end()) {
|
||||
MS_LOG(EXCEPTION) << "Graph id " << AnfAlgo::GetGraphId(arg.get()) << " not found.";
|
||||
}
|
||||
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(arg), NOT_NULL(parameter));
|
||||
InsertMultipleAssignToGraph(NOT_NULL(target_graph_iter->second), NOT_NULL(kg), NOT_NULL(arg),
|
||||
NOT_NULL(parameter));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -433,7 +449,8 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
|
|||
return {partial_cnode, branch_kg};
|
||||
}
|
||||
|
||||
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
|
||||
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
|
||||
NotNull<KernelGraphPtr> to_graph, NotNull<AnfNodePtr> from,
|
||||
NotNull<AnfNodePtr> to) {
|
||||
std::vector<AnfNodePtr> from_outputs = AnfAlgo::GetAllOutput(from, {prim::kPrimTupleGetItem});
|
||||
std::vector<AnfNodePtr> to_outputs = AnfAlgo::GetAllOutput(to, {prim::kPrimTupleGetItem});
|
||||
|
@ -443,18 +460,24 @@ void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg
|
|||
<< to_outputs.size() << "]";
|
||||
}
|
||||
for (size_t i = 0; i < from_outputs.size(); i++) {
|
||||
InsertAssignToGraph(kg, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
|
||||
auto assign_node = InsertAssignToGraph(from_graph, NOT_NULL(from_outputs[i]), NOT_NULL(to_outputs[i]));
|
||||
if (assign_node != nullptr) {
|
||||
auto jump_node = GetJumpNode(from_graph, to_graph);
|
||||
if (jump_node != nullptr) {
|
||||
InsertControlDependToGraph(from_graph, NOT_NULL(assign_node), NOT_NULL(jump_node));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
|
||||
NotNull<AnfNodePtr> to) {
|
||||
AnfNodePtr AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from,
|
||||
NotNull<AnfNodePtr> to) {
|
||||
if (AnfAlgo::OutputAddrExist(from, 0) && AnfAlgo::OutputAddrExist(to, 0) &&
|
||||
AnfAlgo::GetOutputAddr(from, 0) == AnfAlgo::GetOutputAddr(to, 0)) {
|
||||
return;
|
||||
return nullptr;
|
||||
}
|
||||
if (from.get() == to.get()) {
|
||||
return;
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "Insert assign to graph " << kg->ToString() << " from " << from->DebugString() << " to "
|
||||
<< to->DebugString();
|
||||
|
@ -466,6 +489,7 @@ void AscendControlParser::InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNul
|
|||
assign_node->set_abstract(to->abstract());
|
||||
// append the assign at the end of from graph
|
||||
InsertDependToGraph(kg, NOT_NULL(assign_node));
|
||||
return assign_node;
|
||||
}
|
||||
|
||||
std::vector<CNodePtr> AscendControlParser::RecurseGraph(NotNull<KernelGraphPtr> graph,
|
||||
|
|
|
@ -52,8 +52,9 @@ class AscendControlParser {
|
|||
const CNodePtr &last_label);
|
||||
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
|
||||
|
||||
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
static void InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
|
||||
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
static AnfNodePtr InsertAssignToGraph(NotNull<KernelGraphPtr> kg, NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
|
||||
|
||||
// root graph order
|
||||
static bool CheckLabelIndex(uint32_t order_index, uint32_t label_index, const CNodePtr &cnode,
|
||||
|
|
Loading…
Reference in New Issue