match ub fusion pass
This commit is contained in:
parent
ae7175f522
commit
432192b1d8
|
@ -70,10 +70,8 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
|
|||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
|
||||
auto eltwise_input = cnode->input(1);
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
|
||||
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
}
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
|
||||
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -65,10 +65,8 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr
|
|||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
|
||||
auto eltwise_input = cnode->input(1);
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
|
||||
MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
}
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
|
||||
MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -74,11 +74,8 @@ void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::Ker
|
|||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
|
||||
auto eltwise_input = cnode->input(1);
|
||||
if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
|
||||
if (eltwise_input->isa<CNode>() &&
|
||||
AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
|
||||
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
|
||||
}
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
|
||||
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
|
||||
}
|
||||
} else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
|
||||
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);
|
||||
|
|
|
@ -55,6 +55,7 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker
|
|||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
std::reverse(node_list.begin(), node_list.end());
|
||||
for (auto &node : node_list) {
|
||||
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
|
||||
|
|
|
@ -73,6 +73,7 @@ void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGra
|
|||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
std::reverse(node_list.begin(), node_list.end());
|
||||
for (auto &node : node_list) {
|
||||
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
|
||||
AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
|
||||
|
|
Loading…
Reference in New Issue