From 60fc029307e6239de6cfd37f869be2031e54cb36 Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Tue, 13 Oct 2020 09:35:28 +0800 Subject: [PATCH] add atomic clean for all type communication input --- .../ccsrc/runtime/device/ascend/kernel_build_ascend.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc index 58d0480e487..5f6b198c792 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_build_ascend.cc @@ -263,13 +263,17 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) { std::map> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph); for (const auto &anf_node : kernel_graph->execution_order()) { std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); + bool is_comm_input = false; if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) { auto indexes = comm_input_info_map[anf_node]; AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node); + is_comm_input = true; } - if (apply_function_name == prim::kPrimMaxPoolGrad->name() && - AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { + if (is_comm_input) { + AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes); + } else if (apply_function_name == prim::kPrimMaxPoolGrad->name() && + AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { auto clear_zero_prim = std::make_shared(kClearZeroOpName); MS_EXCEPTION_IF_NULL(clear_zero_prim); auto new_value_node = NewValueNode(clear_zero_prim);