!17472 add output_num for ub_fusion

From: @yuchaojie
Reviewed-by: @jjfeing,@zhoufeng54
Signed-off-by: @zhoufeng54
This commit is contained in:
mindspore-ci-bot 2021-06-03 14:42:54 +08:00 committed by Gitee
commit 730e1d3b65
3 changed files with 14 additions and 2 deletions

View File

@ -31,8 +31,6 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
std::unordered_set<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);
@ -41,6 +39,10 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
return;
}
if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
(void)record.insert(eltwise_input);
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -48,6 +48,12 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
}
if (AnfAlgo::GetKernelType(double_in_eltwise_input) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::FusionType::CONVLUTION) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<int64_t> eltwise_output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(eltwise_output_used_num), eltwise_input);
std::vector<int64_t> conv_output_used_num{SizeToLong(manager->node_users()[double_in_eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv_output_used_num), double_in_eltwise_input);
(void)record.insert(double_in_eltwise_input);
candidate_fusion->push_back(record);
SetRecordFusionId(record);

View File

@ -48,6 +48,10 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
}
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::CONVLUTION) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
(void)record.insert(eltwise_input);
candidate_fusion->push_back(record);
SetRecordFusionId(record);