From 311d8ea1f9a939f0bc2013f6a01c4870be2cbdf6 Mon Sep 17 00:00:00 2001
From: huanghui
Date: Wed, 29 Jul 2020 11:45:52 +0800
Subject: [PATCH] add exception when different communication op in one segment shared the same input

---
 .../optimizer/pass/communication_op_fusion.cc | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
index 8c2aaa85dd2..d8f94c4a2cb 100644
--- a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc
@@ -16,6 +16,7 @@
 #include "backend/optimizer/pass/communication_op_fusion.h"
 
 #include <vector>
+#include <set>
 #include <memory>
 #include <unordered_map>
 
@@ -89,6 +90,13 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
   }
   return group + op + std::to_string(fusion);
 }
+
+void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
+  std::set<AnfNodePtr> inputs_set(fusion_inputs.begin(), fusion_inputs.end());
+  if (inputs_set.size() < fusion_inputs.size()) {
+    MS_LOG(EXCEPTION) << "Different communication op in one segment cannot share the same input";
+  }
+}
 }  // namespace
 
 bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
@@ -163,6 +171,7 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
     MS_EXCEPTION_IF_NULL(cnode);
     fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
   }
+  CheckInputs(fusion_inputs);
   AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs);
   MS_EXCEPTION_IF_NULL(fused_node);
   auto kernel_info = std::make_shared<device::KernelInfo>();
@@ -172,9 +181,6 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
   for (size_t idx = start_index; idx <= end_index; ++idx) {
     auto cnode = communication_op_info.communication_op_nodes[idx];
     MS_EXCEPTION_IF_NULL(cnode);
-    AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node);
-    AnfAlgo::CopyNodeAttr("op", cnode, fused_node);
-    AnfAlgo::CopyNodeAttr("group", cnode, fused_node);
     abstract_list.push_back(cnode->abstract());
   }
   auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
@@ -182,6 +188,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
   auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
   MS_EXCEPTION_IF_NULL(abstract_tuple);
   fused_node->set_abstract(abstract_tuple);
+  AnfAlgo::CopyNodeAttr("fusion", communication_op_info.communication_op_nodes[end_index], fused_node);
+  AnfAlgo::CopyNodeAttr("op", communication_op_info.communication_op_nodes[end_index], fused_node);
+  AnfAlgo::CopyNodeAttr("group", communication_op_info.communication_op_nodes[end_index], fused_node);
   return fused_node;
 }
 