From 38e0d98eb5ab1010f0e888a52a8dc5d2fe486061 Mon Sep 17 00:00:00 2001 From: etone-chan Date: Thu, 7 May 2020 20:32:44 +0800 Subject: [PATCH] refactor fusion id implement of buffer fusion --- .../ascend/buffer_fusion/buffer_fusion.cc | 79 ++++++++----------- .../ascend/buffer_fusion/buffer_fusion.h | 15 +++- .../common/fusion_id_allocator.cc | 46 +++++++++++ .../pre_activate/common/fusion_id_allocator.h | 42 ++++++++++ mindspore/ccsrc/utils/utils.h | 1 + 5 files changed, 135 insertions(+), 48 deletions(-) create mode 100644 mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc create mode 100644 mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc index a2313a50d07..41e09910650 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc @@ -32,6 +32,7 @@ #include "operator/ops.h" #include "device/kernel_info.h" #include "utils/context/ms_context.h" +#include "pre_activate/common/fusion_id_allocator.h" namespace mindspore { namespace opt { @@ -79,20 +80,6 @@ void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) { } #endif -void SetAnfNodeFusionId(const FusedNodeRecord &record_node) { - MS_LOG(DEBUG) << "Size of opt vector to be fused is " << record_node.size(); - int32_t id = 1; - for (auto &record : record_node) { - MS_LOG(DEBUG) << "No" << id << ", opt vector to be fused contain " << record.size() << " opt."; - for (const auto &candidate : record) { - ValuePtr fusion_id_v = MakeValue(id); - AnfAlgo::SetNodeAttr(kOpAttrFusionId, fusion_id_v, candidate); - MS_LOG(DEBUG) << "No " << id << ": " << candidate->DebugString(); - } - id++; - } -} - bool CheckEltWiseNode(FuncGraphManager *manager, std::unordered_set *record, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(manager); MS_EXCEPTION_IF_NULL(record); @@ -482,11 +469,18 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector *fused_set, FusedNodeRecord *candidate_fusion) { +void BufferFusion::SetRecordFusionId(const std::unordered_set &record) { + auto id = fusion_id_allocator.AllocateFusionId(); + for (auto node : record) { + fusion_id_allocator.SetFusionId(node, id); + } +} + +void BufferFusion::MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(fused_set); MS_EXCEPTION_IF_NULL(candidate_fusion); auto manager = kernel_graph.manager(); MS_EXCEPTION_IF_NULL(manager); @@ -496,14 +490,13 @@ void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv); std::unordered_set record{cnode, conv}; candidate_fusion->push_back(record); - fused_set->insert(record.begin(), record.end()); + SetRecordFusionId(record); } } -void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, - std::unordered_set *fused_set, FusedNodeRecord *candidate_fusion) { +void BufferFusion::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(fused_set); MS_EXCEPTION_IF_NULL(candidate_fusion); auto manager = kernel_graph.manager(); MS_EXCEPTION_IF_NULL(manager); @@ -520,14 +513,13 @@ void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, cons AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); std::unordered_set record{cnode, bnupdate}; candidate_fusion->push_back(record); - fused_set->insert(record.begin(), record.end()); + SetRecordFusionId(record); } } -void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, - std::unordered_set *fused_set, FusedNodeRecord *candidate_fusion) { +void BufferFusion::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(cnode); - MS_EXCEPTION_IF_NULL(fused_set); MS_EXCEPTION_IF_NULL(candidate_fusion); auto manager = kernel_graph.manager(); MS_EXCEPTION_IF_NULL(manager); @@ -548,41 +540,37 @@ void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, c AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate); std::unordered_set record{cnode, relu_input, bnupdate}; candidate_fusion->push_back(record); - fused_set->insert(record.begin(), record.end()); + SetRecordFusionId(record); } } } -void MatchOpNamePattern(const session::KernelGraph &kernel_graph, std::unordered_set *fused_set, - FusedNodeRecord *candidate_fusion) { - MS_EXCEPTION_IF_NULL(fused_set); +void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(candidate_fusion); std::vector node_list = TopoSort(kernel_graph.get_return()); for (auto &node : node_list) { - if (!AnfAlgo::IsRealCNodeKernel(node) || fused_set->find(node) != fused_set->end()) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node)) { continue; } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) { - MatchConvBnreduce(cnode, kernel_graph, fused_set, candidate_fusion); + MatchConvBnreduce(cnode, kernel_graph, candidate_fusion); } else if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::GetCNodeName(cnode) == prim::kPrimRelu->name()) { auto relu_input = cnode->input(1); if (relu_input->isa() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTensorAdd->name()) { - MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion); + MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion); } else if (relu_input->isa() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) { - MatchBnupdateRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion); + MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion); } } } } -void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unordered_set *fused_set, - FusedNodeRecord *candidate_fusion) { +void BufferFusion::MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { auto manager = kernel_graph.manager(); MS_EXCEPTION_IF_NULL(manager); - MS_EXCEPTION_IF_NULL(fused_set); MS_EXCEPTION_IF_NULL(candidate_fusion); auto return_node = kernel_graph.get_return(); @@ -599,7 +587,7 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord MS_EXCEPTION_IF_NULL(node); todo.pop_front(); std::unordered_set record; - if (visited_set.find(node) != visited_set.end() || fused_set->find(node) != fused_set->end()) { + if (visited_set.find(node) != visited_set.end() || fusion_id_allocator.HasFusionIdAttr(node)) { continue; } // Only fuse real cnode @@ -616,7 +604,7 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord cnode = FindFusionAnfNode(manager.get(), &visited_set, &record, &todo, cnode); if (record.size() >= MIN_PATTERN_SIZE && record.size() <= MAX_PATTERN_SIZE) { candidate_fusion->push_back(record); - fused_set->insert(record.begin(), record.end()); + SetRecordFusionId(record); } if (record.find(cnode) == record.end()) { todo.push_back(cnode); @@ -628,7 +616,6 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord (void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); } } -} // namespace void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, std::unordered_map *buffer_fusion_infos) const { @@ -684,7 +671,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c return change; } -bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) const { +bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) { auto manager = kernel_graph.manager(); MS_EXCEPTION_IF_NULL(manager); auto return_node = kernel_graph.get_return(); @@ -694,14 +681,11 @@ bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_g } MS_LOG(DEBUG) << "MatchBufferFusionPattern start..."; FusedNodeRecord candidate_fusion; - std::unordered_set fused_set; - MatchOpNamePattern(kernel_graph, &fused_set, &candidate_fusion); - MatchFusionTypePattern(kernel_graph, &fused_set, &candidate_fusion); + MatchOpNamePattern(kernel_graph, &candidate_fusion); + MatchFusionTypePattern(kernel_graph, &candidate_fusion); - if (!candidate_fusion.empty()) { - SetAnfNodeFusionId(candidate_fusion); - } else { + if (candidate_fusion.empty()) { return false; } MS_LOG(DEBUG) << "MatchBufferFusionPattern Success..."; @@ -741,13 +725,14 @@ bool BufferFusion::Run(const FuncGraphPtr &graph) { auto kernel_graph = graph->cast>(); MS_EXCEPTION_IF_NULL(kernel_graph); + fusion_id_allocator.Init(); if (MatchBufferFusionPattern(*kernel_graph)) { changed = FuseBufferFusionPattern(kernel_graph.get()); } // clear fusion_id attr for (auto &node : graph->nodes()) { if (node != nullptr && node->isa()) { - AnfAlgo::EraseNodeAttr(kOpAttrFusionId, node); + AnfAlgo::EraseNodeAttr(kAttrFusionId, node); } } return changed; diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h index f2fa63601b9..fa0a0462202 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h @@ -21,6 +21,7 @@ #include "ir/anf.h" #include "pre_activate/common/pass.h" +#include "pre_activate/common/fusion_id_allocator.h" #include "device/kernel_info.h" #include "kernel/kernel.h" #include "session/kernel_graph.h" @@ -43,12 +44,24 @@ class BufferFusion : public Pass { bool Run(const FuncGraphPtr &graph) override; private: + void SetRecordFusionId(const std::unordered_set &record); + void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); + void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); + void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, + const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); + void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); + void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); + void GetBufferFusionInfo(session::KernelGraph *kernel_graph, std::unordered_map *buffer_fusion_infos) const; bool ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const; - bool MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) const; + bool MatchBufferFusionPattern(const session::KernelGraph &kernel_graph); bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const; + + FusionIdAllocator fusion_id_allocator; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc new file mode 100644 index 00000000000..393546278a5 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/common/fusion_id_allocator.h" +#include "session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +FusionIdAllocator::FusionIdAllocator() { fusion_id = 0; } + +FusionIdAllocator::~FusionIdAllocator() {} + +void FusionIdAllocator::Init() { fusion_id = 0; } + +int32_t FusionIdAllocator::AllocateFusionId() { + fusion_id++; + return fusion_id; +} + +bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) { return AnfAlgo::HasNodeAttr(kAttrFusionId, node); } + +int32_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) { + if (HasFusionIdAttr(node)) { + return AnfAlgo::GetNodeAttr(node, kAttrFusionId); + } + return -1; +} + +void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int32_t id) { + ValuePtr fusion_id_v = MakeValue(id); + AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h new file mode 100644 index 00000000000..8d0cbb3e76a --- /dev/null +++ b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h @@ -0,0 +1,42 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ + +#include "ir/base.h" + +namespace mindspore { +namespace opt { +class FusionIdAllocator { + public: + FusionIdAllocator(); + virtual ~FusionIdAllocator(); + FusionIdAllocator(const FusionIdAllocator &in) = delete; + FusionIdAllocator &operator=(const FusionIdAllocator &in) = delete; + + void Init(); + int32_t AllocateFusionId(); + bool HasFusionIdAttr(const AnfNodePtr &node); + int32_t GetFusionId(const AnfNodePtr &node); + void SetFusionId(const AnfNodePtr &node, int32_t id); + + private: + int32_t fusion_id; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 588dfa14056..b85bc1b0f84 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -165,6 +165,7 @@ constexpr auto kAttrFusion = "fusion"; constexpr auto kAttrGroup = "group"; constexpr auto kAttrOp = "op"; constexpr auto kAttrIsTraining = "is_training"; +constexpr auto kAttrFusionId = "fusion_id"; // attr value constexpr auto kValueTargetSwitch = "target_switch";