From 2e2e7a28ae0160a350243c657b7adf90bcc8834b Mon Sep 17 00:00:00 2001 From: "Etone.Chan" Date: Mon, 27 Apr 2020 22:13:02 +0800 Subject: [PATCH] refactor buffer fusion --- .../ascend/buffer_fusion/buffer_fusion.cc | 143 +++++++----------- .../ascend/buffer_fusion/buffer_fusion.h | 1 - .../ccsrc/session/anf_runtime_algorithm.cc | 2 + mindspore/ccsrc/utils/utils.h | 2 + 4 files changed, 58 insertions(+), 90 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc index abacb9137d1..8581f1165df 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc @@ -261,23 +261,24 @@ CNodePtr CreateFusionOp(const std::vector &inputs_list, const std::v return buffer_fusion_kernel; } -kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector &inputs_list_in, - const std::vector &inputs_list, +kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector &inputs_list, const std::vector &outputs_list) { MS_LOG(DEBUG) << "Start Create Kernel Info"; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; // inputs format and data type std::vector inputs_format; std::vector inputs_data_type; - for (auto node : inputs_list_in) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto &inputs = cnode->inputs(); - for (size_t input_index = 1; input_index < inputs.size(); ++input_index) { - if (std::find(inputs_list.begin(), inputs_list.end(), inputs[input_index]) != inputs_list.end()) { - inputs_format.push_back(AnfAlgo::GetInputFormat(node, input_index - 1)); - inputs_data_type.push_back(AnfAlgo::GetInputDeviceDataType(node, input_index - 1)); - } + for (const auto &input : inputs_list) { + if (input->isa() && AnfAlgo::GetCNodeName(input) == prim::kPrimTupleGetItem->name()) { + auto tuple_getitem = input->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + inputs_format.push_back(AnfAlgo::GetOutputFormat( + tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); + inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( + tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); + } else { + inputs_format.push_back(AnfAlgo::GetOutputFormat(input, 0)); + inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(input, 0)); } } // outputs format and data type @@ -360,62 +361,6 @@ void ReplaceOldNode(std::unordered_map *buffer_fusi } } -void GetInputList(const CNodePtr &node, const int32_t cur_fusion_id, std::vector *inputs_list) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(inputs_list); - auto &inputs = node->inputs(); - for (size_t input_index = 1; input_index < inputs.size(); ++input_index) { - auto input = inputs[input_index]; - if (AnfAlgo::IsRealCNodeKernel(input)) { - if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, input)) { - auto fusion_id = AnfAlgo::GetNodeAttr(input, kOpAttrFusionId); - if (fusion_id != cur_fusion_id) { - inputs_list->push_back(input); - } - } else { - inputs_list->push_back(input); - } - } else if (input->isa()) { - for (auto &input_in : input->cast()->inputs()) { - if (AnfAlgo::IsRealCNodeKernel(input_in)) { - if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, input_in)) { - auto fusion_id = AnfAlgo::GetNodeAttr(input_in, kOpAttrFusionId); - if (fusion_id != cur_fusion_id) { - inputs_list->push_back(input); - } - } else { - inputs_list->push_back(input); - } - } - } - } else { - inputs_list->push_back(input); - } - } -} - -void CheckCurrentNodeIsInput(const CNodePtr &node, const int32_t &cur_fusion_id, - std::unordered_map *buffer_fusion_infos) { - MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - if ((*buffer_fusion_infos).find(cur_fusion_id) == (*buffer_fusion_infos).end()) { - BufferFusionInfo_t buffer_fusion_info; - (*buffer_fusion_infos)[cur_fusion_id] = buffer_fusion_info; - } - std::vector inputs_list; - GetInputList(node, cur_fusion_id, &inputs_list); - if (!inputs_list.empty()) { - if (!(*buffer_fusion_infos)[cur_fusion_id].inputs_list.empty()) { - (void)(*buffer_fusion_infos)[cur_fusion_id].inputs_list.insert( - (*buffer_fusion_infos)[cur_fusion_id].inputs_list.end(), inputs_list.begin(), inputs_list.end()); - (void)(*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.insert( - (*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.end(), node); - } else { - (*buffer_fusion_infos)[cur_fusion_id].inputs_list = inputs_list; - (*buffer_fusion_infos)[cur_fusion_id].inputs_list_in.push_back(node); - } - } -} - void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, std::unordered_map *buffer_fusion_infos) { MS_EXCEPTION_IF_NULL(buffer_fusion_infos); @@ -429,6 +374,45 @@ void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, } } +void GetFusionScopeInputNodeList(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + for (auto &buffer_fusion_info : *buffer_fusion_infos) { + auto fusion_id = buffer_fusion_info.first; + auto fusion_info = buffer_fusion_info.second; + for (const auto &node : fusion_info.anf_nodes) { + auto cnode = node->cast(); + for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { + auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); + if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == + fusion_info.anf_nodes.end()) { + if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), + (*buffer_fusion_infos)[fusion_id].inputs_list.end(), + cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { + (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx)); + } + } + } + } + } +} + +bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) { + MS_EXCEPTION_IF_NULL(node1); + MS_EXCEPTION_IF_NULL(node2); + auto getitem1 = node1->cast(); + auto getitem2 = node2->cast(); + MS_EXCEPTION_IF_NULL(getitem1); + MS_EXCEPTION_IF_NULL(getitem2); + auto output_idx1 = GetValue(GetValueNode(getitem1->input(2))); + auto output_idx2 = GetValue(GetValueNode(getitem2->input(2))); + return output_idx1 < output_idx2; +} + void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, std::unordered_map *buffer_fusion_infos) { MS_EXCEPTION_IF_NULL(kernel_graph); @@ -454,14 +438,7 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), std::back_inserter(tuple_getitem_nodes), [](const std::pair &use_node) { return use_node.first; }); - std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), - [](const AnfNodePtr &node1, const AnfNodePtr &node2) { - auto getitem1 = node1->cast(); - auto getitem2 = node2->cast(); - auto output_idx1 = GetValue(GetValueNode(getitem1->input(2))); - auto output_idx2 = GetValue(GetValueNode(getitem2->input(2))); - return output_idx1 < output_idx2; - }); + std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); for (auto getitem : tuple_getitem_nodes) { auto getitem_ptr = getitem->cast(); auto input2 = getitem_ptr->input(2); @@ -634,24 +611,12 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, std::unordered_map *buffer_fusion_infos) const { MS_EXCEPTION_IF_NULL(buffer_fusion_infos); - std::vector node_list = TopoSort(kernel_graph->get_return()); - for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node)) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) { - auto cur_fusion_id = AnfAlgo::GetNodeAttr(cnode, kOpAttrFusionId); - CheckCurrentNodeIsInput(cnode, cur_fusion_id, buffer_fusion_infos); - } - } GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); + GetFusionScopeInputNodeList(kernel_graph, buffer_fusion_infos); GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); for (auto &buffer_fusion_info : *buffer_fusion_infos) { buffer_fusion_info.second.kernel_build_info = - CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list_in, buffer_fusion_info.second.inputs_list, - buffer_fusion_info.second.outputs_list); + CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list); } } diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h index 9bed7217dd1..f2fa63601b9 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h @@ -30,7 +30,6 @@ namespace opt { struct BufferFusionInfo_t { std::vector anf_nodes; std::vector inputs_list; - std::vector inputs_list_in; std::vector outputs_list; kernel::KernelBuildInfoPtr kernel_build_info; }; diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index e1a18d95da0..dbf70979706 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -816,6 +816,8 @@ size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_n MS_EXCEPTION_IF_NULL(anf_node); static std::map> spec_node_list = { {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}}, + {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}}, + {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}}, {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}}, {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}}, {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}}, diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 4dfc4baa311..904acff975a 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -122,6 +122,8 @@ constexpr auto kSendOpName = "Send"; constexpr auto kRecvOpName = "Recv"; constexpr auto kReluV2OpName = "ReLUV2"; constexpr auto kReluGradV2OpName = "ReluGradV2"; +constexpr auto kFusionOpConv2DBackpropInputReluGradV2Name = "FusionOp_Conv2DBackpropInput_ReluGradV2"; +constexpr auto kFusionOpConv2DBackpropInputAddNReluGradV2Name = "FusionOp_Conv2DBackpropInput_AddN_ReluGradV2"; // attr key name constexpr auto kAttrInputNames = "input_names";