!18764 move OutputUsedNum attr setting from fusion passes to ub_pattern_fusion

Merge pull request !18764 from yuchaojie/ub_fusion2
This commit is contained in:
i-robot 2021-06-24 02:50:19 +00:00 committed by Gitee
commit 76ef295b5c
14 changed files with 32 additions and 68 deletions

View File

@ -37,8 +37,6 @@ void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePt
auto batch_matmul = cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(batch_matmul);
if (batch_matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(batch_matmul, prim::kPrimBatchMatMul)) {
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[batch_matmul].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), batch_matmul);
std::unordered_set<AnfNodePtr> record{cnode, batch_matmul};
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -53,29 +53,13 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
auto add = relu_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(add);
auto tuple_getitem = add->input(kIndex1);
std::vector<int64_t> add_output_used_num;
add_output_used_num.emplace_back(SizeToLong(manager->node_users()[add].size()));
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(add_output_used_num), add);
MS_EXCEPTION_IF_NULL(tuple_getitem);
if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
auto getitem = tuple_getitem->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem);
auto bnupdate = getitem->input(kIndex1);
auto bnupdate = getitem->input(kRealInputNodeIndexInTupleGetItem);
MS_EXCEPTION_IF_NULL(bnupdate);
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
continue;
}
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(kIndex2);
auto output_idx = GetValue<int64_t>(GetValueNode(input2));
output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
}
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
std::unordered_set<AnfNodePtr> record{cnode, relu_input, bnupdate};
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -40,19 +40,6 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
auto bnupdate = getitem->input(kIndex1);
MS_EXCEPTION_IF_NULL(bnupdate);
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
for (auto out_getitem : manager->node_users()[bnupdate]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
continue;
}
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto input2 = out_getitem_ptr->input(kIndex2);
auto output_idx = GetValue<int64_t>(GetValueNode(input2));
output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
}
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -44,7 +44,6 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
MS_EXCEPTION_IF_NULL(input_cnode);
auto double_in_eltwise_input = input_cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
std::vector<int64_t> conv2d_bp_output_used_num;
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) {
return;
}
@ -53,8 +52,6 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
(void)record.insert(double_in_eltwise_input);
candidate_fusion->push_back(record);
SetRecordFusionId(record);
conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input].size()));
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input);
} else {
auto double_in_eltwise_input_1 = input_cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1);
@ -66,13 +63,8 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
(void)record.insert(double_in_eltwise_input_1);
candidate_fusion->push_back(record);
SetRecordFusionId(record);
conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input_1].size()));
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input_1);
}
}
std::vector<int64_t> eltwise_output_used_num;
eltwise_output_used_num.emplace_back(SizeToLong(manager->node_users()[input_cnode].size()));
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(eltwise_output_used_num), eltwise_input);
}
void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,

View File

@ -39,10 +39,6 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
return;
}
if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
(void)record.insert(eltwise_input);
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -37,8 +37,6 @@ void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const sess
auto conv = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(conv);
if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) {
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[conv].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);
std::unordered_set<AnfNodePtr> record{cnode, conv};
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -48,12 +48,6 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
}
if (AnfAlgo::GetKernelType(double_in_eltwise_input) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::FusionType::CONVLUTION) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<int64_t> eltwise_output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(eltwise_output_used_num), eltwise_input);
std::vector<int64_t> conv_output_used_num{SizeToLong(manager->node_users()[double_in_eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv_output_used_num), double_in_eltwise_input);
(void)record.insert(double_in_eltwise_input);
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -48,10 +48,6 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
}
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::CONVLUTION) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
(void)record.insert(eltwise_input);
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -40,8 +40,6 @@ void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnod
auto depthwise_conv = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(depthwise_conv);
if (cnode->isa<CNode>() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) {
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[depthwise_conv].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv);
std::unordered_set<AnfNodePtr> record{cnode, depthwise_conv};
candidate_fusion->push_back(record);
SetRecordFusionId(record);
@ -51,8 +49,6 @@ void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnod
auto relu = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(relu);
if (cnode->isa<CNode>() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) {
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[relu].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
std::unordered_set<AnfNodePtr> record{cnode, relu};
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -38,8 +38,6 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
MS_EXCEPTION_IF_NULL(matmul);
if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) {
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[matmul].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), matmul);
std::unordered_set<AnfNodePtr> record{cnode, matmul};
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -36,8 +36,6 @@ void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const An
if (fusion_id_allocator->HasFusionIdAttr(relu_input)) {
return;
}
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[relu_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input);
std::unordered_set<AnfNodePtr> record{cnode, relu_input};
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -36,8 +36,6 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckMultiOutputEltWiseNode(kernel_graph, eltwise_input)) {
std::vector<int64_t> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
(void)record.insert(eltwise_input);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);

View File

@ -64,8 +64,6 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
}
}
candidate_fusion->push_back(record);
std::vector<int64_t> output_used_num{SizeToLong(kernel_graph.manager()->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
SetRecordFusionId(record);
}
}

View File

@ -337,6 +337,36 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
}
}
void SetOutputUsedNumAttr(const session::KernelGraph &kernel_graph,
const std::unordered_map<int64_t, BufferFusionInfo_t> &buffer_fusion_infos) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
for (auto &fusion_info : buffer_fusion_infos) {
auto &fusion_nodes = fusion_info.second.anf_nodes;
for (auto iter = fusion_nodes.begin(); iter != fusion_nodes.end() - 1; ++iter) {
auto node = *iter;
auto output_num = AnfAlgo::GetOutputTensorNum(node);
std::vector<int64_t> output_used_num(output_num, 0);
if (output_num == 1) {
output_used_num[0] = SizeToLong(manager->node_users()[node].size());
} else {
for (auto out_getitem : manager->node_users()[node]) {
MS_EXCEPTION_IF_NULL(out_getitem.first);
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
continue;
}
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
auto getitem_input2 = out_getitem_ptr->input(kInputNodeOutputIndexInTupleGetItem);
auto output_idx = GetValue<int64_t>(GetValueNode(getitem_input2));
output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
}
}
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), node);
}
}
}
void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<AnfNodePtr> &outputs_list,
const AnfNodePtr &fusion_kernel) {
MS_EXCEPTION_IF_NULL(kernel_graph);
@ -410,6 +440,7 @@ void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
// Remove the fusion infos which will produce a circle if do fusion
RemoveCircle(*kernel_graph, buffer_fusion_infos);
SetOutputUsedNumAttr(*kernel_graph, *buffer_fusion_infos);
for (auto &buffer_fusion_info : *buffer_fusion_infos) {
buffer_fusion_info.second.kernel_build_info =