!46218 add eltwise-broadcast fusion

Merge pull request !46218 from laiyongqiang/bug_fix
This commit is contained in:
i-robot 2022-11-30 06:24:09 +00:00 committed by Gitee
commit 961c45f07d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 38 additions and 3 deletions

View File

@ -31,7 +31,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
mindspore::HashSet<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
while (CheckEltWiseOrBroadCastNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
if (record.size() == MAX_ELTWISE_SIZE) {
break;
@ -40,7 +40,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
MS_EXCEPTION_IF_NULL(input_cnode);
eltwise_input = input_cnode->input(kIndex1);
}
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
if (CheckDoubleInEltWiseOrBroadCastNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
}
if (record.size() < MIN_ELTWISE_SIZE) {
@ -71,5 +71,38 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker
}
}
}
bool EltwiseFusionPass::CheckEltWiseOrBroadCastNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
(AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE ||
AnfAlgo::GetFusionType(node) == kernel::FusionType::BROAD_CAST) &&
not_updatestate_nums == ELTWISE_USE && cnode->inputs().size() == ELTWISE_INPUT_SIZE;
}
bool EltwiseFusionPass::CheckDoubleInEltWiseOrBroadCastNode(const session::KernelGraph &kernel_graph,
const AnfNodePtr &node) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
(AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE ||
AnfAlgo::GetFusionType(node) == kernel::FusionType::BROAD_CAST) &&
not_updatestate_nums == ELTWISE_USE && cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
}
} // namespace opt
} // namespace mindspore

View File

@ -38,6 +38,8 @@ class EltwiseFusionPass : public FusionBasePass {
private:
void MatchEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
bool CheckEltWiseOrBroadCastNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
bool CheckDoubleInEltWiseOrBroadCastNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
};
} // namespace opt
} // namespace mindspore

View File

@ -315,7 +315,7 @@ super_bar_config = {
"FusionOp_Conv2DBackpropInputD_ReluGradV2": [1, 0, 2]
},
"SkipDynamicCompileStatic": ["SoftmaxV2", "PRelu", "Trunc", "AccumulateNV2",
"SoftmaxCrossEntropyWithLogits"],
"SoftmaxCrossEntropyWithLogits", "ReduceMeanD", "SquareSumV1"],
# BroadcastTo: The name is occupied
# DynamicBroadcastTo: The name is occupied
# BatchToSpaceD: attr type is listIntnot listListInt