From d9fab277e3e0978f718fefb891adf15f952183ed Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Mon, 14 Sep 2020 16:18:14 +0800 Subject: [PATCH] make ref edage using same address --- .../format_type/deal_ref_trans_and_cast.cc | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc index 88a8e7a9c04..41de6650a76 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc @@ -62,6 +62,16 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { return kernel_with_index; } +void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index, + const size_t input_index) { + // record the ref_pair + auto kernel_graph = func_graph->cast(); + MS_EXCEPTION_IF_NULL(kernel_graph); + session::AnfWithOutIndex final_pair = std::make_pair(cnode, output_index); + session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0); + kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index); +} + void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, const AnfNodePtr &final_node, size_t final_index, const session::KernelWithIndex &origin_pair) { @@ -88,6 +98,7 @@ void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cno AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, size_t input_index, const AnfNodePtr &get_item) { AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); + bool need_refresh_ref_addr = false; size_t final_index = output_index; AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); session::KernelWithIndex origin_pair; @@ -109,6 +120,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type); final_index = 0; + need_refresh_ref_addr = true; MS_EXCEPTION_IF_NULL(final_node); MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); } @@ -119,15 +131,19 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP MS_EXCEPTION_IF_NULL(final_node); final_node->set_scope(cnode->scope()); final_index = 0; + need_refresh_ref_addr = true; MS_LOG(INFO) << "DealRefTransAndCast add cast op, op debug info is " << final_node->DebugString(); } // add ref pair AddRefPairToKernelGraph(func_graph, cnode, get_item, final_node, final_index, origin_pair); + if (need_refresh_ref_addr) { + AddRefNodePairToKernelGraph(func_graph, cnode, output_index, input_index); + } // insert depend if (origin_format != cur_format || origin_type != cur_type) { std::vector depend_nodes{NewValueNode(prim::kPrimDepend), cnode, final_node}; final_node = func_graph->NewCNode(depend_nodes); - MS_LOG(INFO) << "DealRefTransAndCast add denpend, op debug info is " << final_node->DebugString(); + MS_LOG(INFO) << "DealRefTranshwAndCast add denpend, op debug info is " << final_node->DebugString(); } return final_node;