add atomic clean for all type communication input

This commit is contained in:
laiyongqiang 2020-10-13 09:35:28 +08:00
parent 39bc43e674
commit 60fc029307
1 changed files with 6 additions and 2 deletions

View File

@ -263,13 +263,17 @@ void KernelBuildPreprocess(mindspore::session::KernelGraph *kernel_graph) {
std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph); std::map<AnfNodePtr, std::vector<size_t>> comm_input_info_map = GetCommunicationOpInputInfo(kernel_graph);
for (const auto &anf_node : kernel_graph->execution_order()) { for (const auto &anf_node : kernel_graph->execution_order()) {
std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node); std::string apply_function_name = AnfAlgo::GetCNodeName(anf_node);
bool is_comm_input = false;
if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) { if (comm_input_info_map.find(anf_node) != comm_input_info_map.end()) {
auto indexes = comm_input_info_map[anf_node]; auto indexes = comm_input_info_map[anf_node];
AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node); AnfAlgo::SetNodeAttr(kAttrAtomicOutputIndexs, MakeValue(indexes), anf_node);
is_comm_input = true;
} }
if (apply_function_name == prim::kPrimMaxPoolGrad->name() && if (is_comm_input) {
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) { AddTbeClearZeroNode(kernel_graph, anf_node, &new_nodes);
} else if (apply_function_name == prim::kPrimMaxPoolGrad->name() &&
AnfAlgo::GetKernelType(anf_node) == KernelType::AKG_KERNEL) {
auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName); auto clear_zero_prim = std::make_shared<Primitive>(kClearZeroOpName);
MS_EXCEPTION_IF_NULL(clear_zero_prim); MS_EXCEPTION_IF_NULL(clear_zero_prim);
auto new_value_node = NewValueNode(clear_zero_prim); auto new_value_node = NewValueNode(clear_zero_prim);