!46218 add eltwise-broadcast fusion
Merge pull request !46218 from laiyongqiang/bug_fix
This commit is contained in:
commit
961c45f07d
|
@ -31,7 +31,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
|
|||
mindspore::HashSet<AnfNodePtr> record{cnode};
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
while (CheckEltWiseOrBroadCastNode(kernel_graph, eltwise_input)) {
|
||||
(void)record.insert(eltwise_input);
|
||||
if (record.size() == MAX_ELTWISE_SIZE) {
|
||||
break;
|
||||
|
@ -40,7 +40,7 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
|
|||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
eltwise_input = input_cnode->input(kIndex1);
|
||||
}
|
||||
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
if (CheckDoubleInEltWiseOrBroadCastNode(kernel_graph, eltwise_input)) {
|
||||
(void)record.insert(eltwise_input);
|
||||
}
|
||||
if (record.size() < MIN_ELTWISE_SIZE) {
|
||||
|
@ -71,5 +71,38 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool EltwiseFusionPass::CheckEltWiseOrBroadCastNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
|
||||
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
|
||||
(AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE ||
|
||||
AnfAlgo::GetFusionType(node) == kernel::FusionType::BROAD_CAST) &&
|
||||
not_updatestate_nums == ELTWISE_USE && cnode->inputs().size() == ELTWISE_INPUT_SIZE;
|
||||
}
|
||||
|
||||
bool EltwiseFusionPass::CheckDoubleInEltWiseOrBroadCastNode(const session::KernelGraph &kernel_graph,
|
||||
const AnfNodePtr &node) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
|
||||
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
|
||||
(AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE ||
|
||||
AnfAlgo::GetFusionType(node) == kernel::FusionType::BROAD_CAST) &&
|
||||
not_updatestate_nums == ELTWISE_USE && cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,6 +38,8 @@ class EltwiseFusionPass : public FusionBasePass {
|
|||
|
||||
private:
|
||||
void MatchEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||
bool CheckEltWiseOrBroadCastNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
|
||||
bool CheckDoubleInEltWiseOrBroadCastNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -315,7 +315,7 @@ super_bar_config = {
|
|||
"FusionOp_Conv2DBackpropInputD_ReluGradV2": [1, 0, 2]
|
||||
},
|
||||
"SkipDynamicCompileStatic": ["SoftmaxV2", "PRelu", "Trunc", "AccumulateNV2",
|
||||
"SoftmaxCrossEntropyWithLogits"],
|
||||
"SoftmaxCrossEntropyWithLogits", "ReduceMeanD", "SquareSumV1"],
|
||||
# BroadcastTo: The name is occupied
|
||||
# DynamicBroadcastTo: The name is occupied
|
||||
# BatchToSpaceD: attr type is listInt,not listListInt
|
||||
|
|
Loading…
Reference in New Issue