forked from mindspore-Ecosystem/mindspore
!18764 move OutputUsedNum attr setting from fusion passes to ub_pattern_fusion
Merge pull request !18764 from yuchaojie/ub_fusion2
This commit is contained in:
commit
76ef295b5c
|
@ -37,8 +37,6 @@ void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePt
|
|||
auto batch_matmul = cnode->input(kIndex2);
|
||||
MS_EXCEPTION_IF_NULL(batch_matmul);
|
||||
if (batch_matmul->isa<CNode>() && AnfAlgo::CheckPrimitiveType(batch_matmul, prim::kPrimBatchMatMul)) {
|
||||
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[batch_matmul].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), batch_matmul);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, batch_matmul};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -53,29 +53,13 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
|
|||
auto add = relu_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(add);
|
||||
auto tuple_getitem = add->input(kIndex1);
|
||||
std::vector<int64_t> add_output_used_num;
|
||||
add_output_used_num.emplace_back(SizeToLong(manager->node_users()[add].size()));
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(add_output_used_num), add);
|
||||
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
||||
if (tuple_getitem->isa<CNode>() && AnfAlgo::GetCNodeName(tuple_getitem) == prim::kPrimTupleGetItem->name()) {
|
||||
auto getitem = tuple_getitem->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(getitem);
|
||||
auto bnupdate = getitem->input(kIndex1);
|
||||
auto bnupdate = getitem->input(kRealInputNodeIndexInTupleGetItem);
|
||||
MS_EXCEPTION_IF_NULL(bnupdate);
|
||||
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
|
||||
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
|
||||
for (auto out_getitem : manager->node_users()[bnupdate]) {
|
||||
MS_EXCEPTION_IF_NULL(out_getitem.first);
|
||||
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
|
||||
auto input2 = out_getitem_ptr->input(kIndex2);
|
||||
auto output_idx = GetValue<int64_t>(GetValueNode(input2));
|
||||
output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, relu_input, bnupdate};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -40,19 +40,6 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
|
|||
auto bnupdate = getitem->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(bnupdate);
|
||||
if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
|
||||
std::vector<int64_t> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
|
||||
for (auto out_getitem : manager->node_users()[bnupdate]) {
|
||||
MS_EXCEPTION_IF_NULL(out_getitem.first);
|
||||
if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
|
||||
continue;
|
||||
}
|
||||
auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(out_getitem_ptr);
|
||||
auto input2 = out_getitem_ptr->input(kIndex2);
|
||||
auto output_idx = GetValue<int64_t>(GetValueNode(input2));
|
||||
output_used_num[output_idx] = SizeToLong(manager->node_users()[out_getitem.first].size());
|
||||
}
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -44,7 +44,6 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
|
|||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
auto double_in_eltwise_input = input_cnode->input(kIndex2);
|
||||
MS_EXCEPTION_IF_NULL(double_in_eltwise_input);
|
||||
std::vector<int64_t> conv2d_bp_output_used_num;
|
||||
if (!double_in_eltwise_input->isa<CNode>() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input)) {
|
||||
return;
|
||||
}
|
||||
|
@ -53,8 +52,6 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
|
|||
(void)record.insert(double_in_eltwise_input);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input].size()));
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input);
|
||||
} else {
|
||||
auto double_in_eltwise_input_1 = input_cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(double_in_eltwise_input_1);
|
||||
|
@ -66,13 +63,8 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
|
|||
(void)record.insert(double_in_eltwise_input_1);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
conv2d_bp_output_used_num.emplace_back(SizeToLong(manager->node_users()[double_in_eltwise_input_1].size()));
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv2d_bp_output_used_num), double_in_eltwise_input_1);
|
||||
}
|
||||
}
|
||||
std::vector<int64_t> eltwise_output_used_num;
|
||||
eltwise_output_used_num.emplace_back(SizeToLong(manager->node_users()[input_cnode].size()));
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(eltwise_output_used_num), eltwise_input);
|
||||
}
|
||||
|
||||
void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
|
|
|
@ -39,10 +39,6 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
|
|||
return;
|
||||
}
|
||||
if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
|
||||
(void)record.insert(eltwise_input);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -37,8 +37,6 @@ void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const sess
|
|||
auto conv = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(conv);
|
||||
if (conv->isa<CNode>() && AnfAlgo::GetCNodeName(conv) == prim::kPrimConv2D->name()) {
|
||||
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[conv].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, conv};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -48,12 +48,6 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
|
|||
}
|
||||
if (AnfAlgo::GetKernelType(double_in_eltwise_input) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::FusionType::CONVLUTION) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<int64_t> eltwise_output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(eltwise_output_used_num), eltwise_input);
|
||||
std::vector<int64_t> conv_output_used_num{SizeToLong(manager->node_users()[double_in_eltwise_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv_output_used_num), double_in_eltwise_input);
|
||||
(void)record.insert(double_in_eltwise_input);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -48,10 +48,6 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
|
|||
}
|
||||
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::CONVLUTION) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
|
||||
(void)record.insert(eltwise_input);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -40,8 +40,6 @@ void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnod
|
|||
auto depthwise_conv = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(depthwise_conv);
|
||||
if (cnode->isa<CNode>() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) {
|
||||
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[depthwise_conv].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, depthwise_conv};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
@ -51,8 +49,6 @@ void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnod
|
|||
auto relu = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(relu);
|
||||
if (cnode->isa<CNode>() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) {
|
||||
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[relu].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, relu};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -38,8 +38,6 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
|
|||
MS_EXCEPTION_IF_NULL(matmul);
|
||||
if (matmul->isa<CNode>() && (AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimMatMul) ||
|
||||
AnfAlgo::CheckPrimitiveType(matmul, prim::kPrimBatchMatMul))) {
|
||||
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[matmul].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), matmul);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, matmul};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -36,8 +36,6 @@ void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const An
|
|||
if (fusion_id_allocator->HasFusionIdAttr(relu_input)) {
|
||||
return;
|
||||
}
|
||||
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[relu_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input);
|
||||
std::unordered_set<AnfNodePtr> record{cnode, relu_input};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -36,8 +36,6 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
|
|||
auto eltwise_input = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
if (CheckMultiOutputEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
std::vector<int64_t> output_used_num{SizeToInt(manager->node_users()[eltwise_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
|
||||
(void)record.insert(eltwise_input);
|
||||
auto input_cnode = eltwise_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
|
|
|
@ -64,8 +64,6 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
|
|||
}
|
||||
}
|
||||
candidate_fusion->push_back(record);
|
||||
std::vector<int64_t> output_used_num{SizeToLong(kernel_graph.manager()->node_users()[eltwise_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
|
||||
SetRecordFusionId(record);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -337,6 +337,36 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
|
|||
}
|
||||
}
|
||||
|
||||
// Stores, on every node of each fusion scope except the scope's last node, how many users
// consume each of the node's outputs. The counts are written to the kAttrOutputUsedNum attribute.
//
// kernel_graph: graph whose manager supplies the node-users map; its manager must not be null.
// buffer_fusion_infos: fusion scopes keyed by an int64_t; only `anf_nodes` of each entry is read.
//
// Single-output node: the count is the size of the node's users set.
// Multi-output node: each TupleGetItem user writes the size of its own users set at the output
// index it selects; other users are ignored and unselected outputs keep a count of 0.
//
// Throws std::out_of_range when a TupleGetItem selects an index outside the node's outputs.
void SetOutputUsedNumAttr(const session::KernelGraph &kernel_graph,
                          const std::unordered_map<int64_t, BufferFusionInfo_t> &buffer_fusion_infos) {
  auto manager = kernel_graph.manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &fusion_info : buffer_fusion_infos) {
    const auto &fusion_nodes = fusion_info.second.anf_nodes;
    // Nothing to annotate in an empty scope; `end() - 1` below would be undefined behavior on it.
    if (fusion_nodes.empty()) {
      continue;
    }
    // The last node of the scope is intentionally excluded from the loop.
    for (auto iter = fusion_nodes.begin(); iter != fusion_nodes.end() - 1; ++iter) {
      const auto &node = *iter;
      const auto output_num = AnfAlgo::GetOutputTensorNum(node);
      std::vector<int64_t> output_used_num(output_num, 0);
      if (output_num == 1) {
        output_used_num[0] = SizeToLong(manager->node_users()[node].size());
      } else {
        for (const auto &out_getitem : manager->node_users()[node]) {
          MS_EXCEPTION_IF_NULL(out_getitem.first);
          if (!AnfAlgo::CheckPrimitiveType(out_getitem.first, prim::kPrimTupleGetItem)) {
            continue;
          }
          auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
          MS_EXCEPTION_IF_NULL(out_getitem_ptr);
          auto getitem_input2 = out_getitem_ptr->input(kInputNodeOutputIndexInTupleGetItem);
          const auto output_idx = GetValue<int64_t>(GetValueNode(getitem_input2));
          // at() rejects a negative (wrapped to a huge size_t) or too-large index instead of
          // silently writing outside the vector.
          output_used_num.at(static_cast<size_t>(output_idx)) =
            SizeToLong(manager->node_users()[out_getitem.first].size());
        }
      }
      AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), node);
    }
  }
}
|
||||
|
||||
void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<AnfNodePtr> &outputs_list,
|
||||
const AnfNodePtr &fusion_kernel) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
@ -410,6 +440,7 @@ void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
|||
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
|
||||
// Remove the fusion infos which will produce a circle if do fusion
|
||||
RemoveCircle(*kernel_graph, buffer_fusion_infos);
|
||||
SetOutputUsedNumAttr(*kernel_graph, *buffer_fusion_infos);
|
||||
|
||||
for (auto &buffer_fusion_info : *buffer_fusion_infos) {
|
||||
buffer_fusion_info.second.kernel_build_info =
|
||||
|
|
Loading…
Reference in New Issue