forked from mindspore-Ecosystem/mindspore
!975 refactor fusion id implement of buffer fusion
Merge pull request !975 from Etone.Chan/master
This commit is contained in:
commit
76a2f7c69d
|
@ -32,6 +32,7 @@
|
|||
#include "operator/ops.h"
|
||||
#include "device/kernel_info.h"
|
||||
#include "utils/context/ms_context.h"
|
||||
#include "pre_activate/common/fusion_id_allocator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -79,20 +80,6 @@ void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) {
|
|||
}
|
||||
#endif
|
||||
|
||||
void SetAnfNodeFusionId(const FusedNodeRecord &record_node) {
|
||||
MS_LOG(DEBUG) << "Size of opt vector to be fused is " << record_node.size();
|
||||
int32_t id = 1;
|
||||
for (auto &record : record_node) {
|
||||
MS_LOG(DEBUG) << "No" << id << ", opt vector to be fused contain " << record.size() << " opt.";
|
||||
for (const auto &candidate : record) {
|
||||
ValuePtr fusion_id_v = MakeValue(id);
|
||||
AnfAlgo::SetNodeAttr(kOpAttrFusionId, fusion_id_v, candidate);
|
||||
MS_LOG(DEBUG) << "No " << id << ": " << candidate->DebugString();
|
||||
}
|
||||
id++;
|
||||
}
|
||||
}
|
||||
|
||||
bool CheckEltWiseNode(FuncGraphManager *manager, std::unordered_set<AnfNodePtr> *record, const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(record);
|
||||
|
@ -482,11 +469,18 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
|
|||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
||||
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
|
||||
void BufferFusion::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) {
|
||||
auto id = fusion_id_allocator.AllocateFusionId();
|
||||
for (auto node : record) {
|
||||
fusion_id_allocator.SetFusionId(node, id);
|
||||
}
|
||||
}
|
||||
|
||||
void BufferFusion::MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(fused_set);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -496,14 +490,13 @@ void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel
|
|||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, conv};
|
||||
candidate_fusion->push_back(record);
|
||||
fused_set->insert(record.begin(), record.end());
|
||||
SetRecordFusionId(record);
|
||||
}
|
||||
}
|
||||
|
||||
void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
|
||||
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
|
||||
void BufferFusion::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
||||
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(fused_set);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -520,14 +513,13 @@ void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, cons
|
|||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
|
||||
candidate_fusion->push_back(record);
|
||||
fused_set->insert(record.begin(), record.end());
|
||||
SetRecordFusionId(record);
|
||||
}
|
||||
}
|
||||
|
||||
void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
|
||||
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
|
||||
void BufferFusion::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
||||
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(fused_set);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
|
@ -548,41 +540,37 @@ void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, c
|
|||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, relu_input, bnupdate};
|
||||
candidate_fusion->push_back(record);
|
||||
fused_set->insert(record.begin(), record.end());
|
||||
SetRecordFusionId(record);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MatchOpNamePattern(const session::KernelGraph &kernel_graph, std::unordered_set<AnfNodePtr> *fused_set,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(fused_set);
|
||||
void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!AnfAlgo::IsRealCNodeKernel(node) || fused_set->find(node) != fused_set->end()) {
|
||||
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node)) {
|
||||
continue;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) {
|
||||
MatchConvBnreduce(cnode, kernel_graph, fused_set, candidate_fusion);
|
||||
MatchConvBnreduce(cnode, kernel_graph, candidate_fusion);
|
||||
} else if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName ||
|
||||
AnfAlgo::GetCNodeName(cnode) == prim::kPrimRelu->name()) {
|
||||
auto relu_input = cnode->input(1);
|
||||
if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTensorAdd->name()) {
|
||||
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion);
|
||||
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion);
|
||||
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) {
|
||||
MatchBnupdateRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion);
|
||||
MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unordered_set<AnfNodePtr> *fused_set,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
void BufferFusion::MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(fused_set);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
|
||||
auto return_node = kernel_graph.get_return();
|
||||
|
@ -599,7 +587,7 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
todo.pop_front();
|
||||
std::unordered_set<AnfNodePtr> record;
|
||||
if (visited_set.find(node) != visited_set.end() || fused_set->find(node) != fused_set->end()) {
|
||||
if (visited_set.find(node) != visited_set.end() || fusion_id_allocator.HasFusionIdAttr(node)) {
|
||||
continue;
|
||||
}
|
||||
// Only fuse real cnode
|
||||
|
@ -616,7 +604,7 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
|
|||
cnode = FindFusionAnfNode(manager.get(), &visited_set, &record, &todo, cnode);
|
||||
if (record.size() >= MIN_PATTERN_SIZE && record.size() <= MAX_PATTERN_SIZE) {
|
||||
candidate_fusion->push_back(record);
|
||||
fused_set->insert(record.begin(), record.end());
|
||||
SetRecordFusionId(record);
|
||||
}
|
||||
if (record.find(cnode) == record.end()) {
|
||||
todo.push_back(cnode);
|
||||
|
@ -628,7 +616,6 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
|
|||
(void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
||||
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
|
||||
|
@ -684,7 +671,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c
|
|||
return change;
|
||||
}
|
||||
|
||||
bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) const {
|
||||
bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto return_node = kernel_graph.get_return();
|
||||
|
@ -694,14 +681,11 @@ bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_g
|
|||
}
|
||||
MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
|
||||
FusedNodeRecord candidate_fusion;
|
||||
std::unordered_set<AnfNodePtr> fused_set;
|
||||
|
||||
MatchOpNamePattern(kernel_graph, &fused_set, &candidate_fusion);
|
||||
MatchFusionTypePattern(kernel_graph, &fused_set, &candidate_fusion);
|
||||
MatchOpNamePattern(kernel_graph, &candidate_fusion);
|
||||
MatchFusionTypePattern(kernel_graph, &candidate_fusion);
|
||||
|
||||
if (!candidate_fusion.empty()) {
|
||||
SetAnfNodeFusionId(candidate_fusion);
|
||||
} else {
|
||||
if (candidate_fusion.empty()) {
|
||||
return false;
|
||||
}
|
||||
MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
|
||||
|
@ -741,13 +725,14 @@ bool BufferFusion::Run(const FuncGraphPtr &graph) {
|
|||
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
||||
fusion_id_allocator.Init();
|
||||
if (MatchBufferFusionPattern(*kernel_graph)) {
|
||||
changed = FuseBufferFusionPattern(kernel_graph.get());
|
||||
}
|
||||
// clear fusion_id attr
|
||||
for (auto &node : graph->nodes()) {
|
||||
if (node != nullptr && node->isa<CNode>()) {
|
||||
AnfAlgo::EraseNodeAttr(kOpAttrFusionId, node);
|
||||
AnfAlgo::EraseNodeAttr(kAttrFusionId, node);
|
||||
}
|
||||
}
|
||||
return changed;
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "ir/anf.h"
|
||||
#include "pre_activate/common/pass.h"
|
||||
#include "pre_activate/common/fusion_id_allocator.h"
|
||||
#include "device/kernel_info.h"
|
||||
#include "kernel/kernel.h"
|
||||
#include "session/kernel_graph.h"
|
||||
|
@ -43,12 +44,24 @@ class BufferFusion : public Pass {
|
|||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record);
|
||||
void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion);
|
||||
void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion);
|
||||
void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
||||
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||
void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||
void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||
|
||||
void GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
||||
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const;
|
||||
bool ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, int32_t fusion_id,
|
||||
const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const;
|
||||
bool MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) const;
|
||||
bool MatchBufferFusionPattern(const session::KernelGraph &kernel_graph);
|
||||
bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const;
|
||||
|
||||
FusionIdAllocator fusion_id_allocator;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "pre_activate/common/fusion_id_allocator.h"
|
||||
#include "session/anf_runtime_algorithm.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Constructs an allocator whose id counter starts at zero.
FusionIdAllocator::FusionIdAllocator() : fusion_id(0) {}
|
||||
|
||||
// Nothing to release explicitly; the destructor body is empty.
FusionIdAllocator::~FusionIdAllocator() = default;
|
||||
|
||||
// Resets the id counter to zero.
void FusionIdAllocator::Init() {
  fusion_id = 0;
}
|
||||
|
||||
// Advances the counter by one and returns the new value.
int32_t FusionIdAllocator::AllocateFusionId() { return ++fusion_id; }
|
||||
|
||||
// Reports whether the node currently carries the kAttrFusionId attribute.
bool FusionIdAllocator::HasFusionIdAttr(const AnfNodePtr &node) {
  const bool tagged = AnfAlgo::HasNodeAttr(kAttrFusionId, node);
  return tagged;
}
|
||||
|
||||
int32_t FusionIdAllocator::GetFusionId(const AnfNodePtr &node) {
|
||||
if (HasFusionIdAttr(node)) {
|
||||
return AnfAlgo::GetNodeAttr<int32_t>(node, kAttrFusionId);
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
void FusionIdAllocator::SetFusionId(const AnfNodePtr &node, int32_t id) {
|
||||
ValuePtr fusion_id_v = MakeValue(id);
|
||||
AnfAlgo::SetNodeAttr(kAttrFusionId, fusion_id_v, node);
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,42 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_
|
||||
#define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_
|
||||
|
||||
#include "ir/base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class FusionIdAllocator {
|
||||
public:
|
||||
FusionIdAllocator();
|
||||
virtual ~FusionIdAllocator();
|
||||
FusionIdAllocator(const FusionIdAllocator &in) = delete;
|
||||
FusionIdAllocator &operator=(const FusionIdAllocator &in) = delete;
|
||||
|
||||
void Init();
|
||||
int32_t AllocateFusionId();
|
||||
bool HasFusionIdAttr(const AnfNodePtr &node);
|
||||
int32_t GetFusionId(const AnfNodePtr &node);
|
||||
void SetFusionId(const AnfNodePtr &node, int32_t id);
|
||||
|
||||
private:
|
||||
int32_t fusion_id;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_
|
|
@ -165,6 +165,7 @@ constexpr auto kAttrFusion = "fusion";
|
|||
constexpr auto kAttrGroup = "group";
|
||||
constexpr auto kAttrOp = "op";
|
||||
constexpr auto kAttrIsTraining = "is_training";
|
||||
// Node attribute key under which a fusion id is stored.
constexpr auto kAttrFusionId = "fusion_id";
|
||||
|
||||
// attr value
|
||||
constexpr auto kValueTargetSwitch = "target_switch";
|
||||
|
|
Loading…
Reference in New Issue