forked from mindspore-Ecosystem/mindspore
!17472 add output_num for ub_fusion
From: @yuchaojie Reviewed-by: @jjfeing,@zhoufeng54 Signed-off-by: @zhoufeng54
This commit is contained in:
commit
730e1d3b65
|
@ -31,8 +31,6 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
|
|||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::unordered_set<AnfNodePtr> record{cnode};
|
||||
auto eltwise_input = cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
|
@ -41,6 +39,10 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
|
|||
return;
|
||||
}
|
||||
if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
|
||||
(void)record.insert(eltwise_input);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -48,6 +48,12 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
|
|||
}
|
||||
if (AnfAlgo::GetKernelType(double_in_eltwise_input) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::FusionType::CONVLUTION) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<int64_t> eltwise_output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(eltwise_output_used_num), eltwise_input);
|
||||
std::vector<int64_t> conv_output_used_num{SizeToLong(manager->node_users()[double_in_eltwise_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(conv_output_used_num), double_in_eltwise_input);
|
||||
(void)record.insert(double_in_eltwise_input);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
|
@ -48,6 +48,10 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
|
|||
}
|
||||
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::CONVLUTION) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
std::vector<int64_t> output_used_num{SizeToLong(manager->node_users()[eltwise_input].size())};
|
||||
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), eltwise_input);
|
||||
(void)record.insert(eltwise_input);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
|
Loading…
Reference in New Issue