forked from mindspore-Ecosystem/mindspore
!4232 modify the condition of pattern match in bnupdate + eltwise fusion pass
Merge pull request !4232 from Etone.Chan/August
This commit is contained in:
commit
a0bfeedfa5
|
@ -27,15 +27,15 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
||||
const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input,
|
||||
const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(relu_input);
|
||||
auto getitem = relu_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
auto getitem = eltwise_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(getitem);
|
||||
auto bnupdate = getitem->input(1);
|
||||
MS_EXCEPTION_IF_NULL(bnupdate);
|
||||
|
@ -68,10 +68,11 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
|
||||
AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) {
|
||||
auto eltwise_input = cnode->input(1);
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
|
||||
MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
MatchBnupdateDoubleOutputEltwise(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -39,8 +39,8 @@ class BnupdateEltwiseFusionPass : public FusionBasePass {
|
|||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
private:
|
||||
void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion);
|
||||
void MatchBnupdateDoubleOutputEltwise(const CNodePtr &cnode, const AnfNodePtr &eltwise_input,
|
||||
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,6 +33,7 @@ const int8_t MAX_ELTWISE_NUM = 3;
|
|||
const int8_t MIN_ELTWISE_SIZE = 2;
|
||||
const int8_t ELTWISE_INPUT_SIZE = 2;
|
||||
const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3;
|
||||
const int8_t ELTWISE_DOUBLE_OUTPUT_SIZE = 2;
|
||||
const int8_t CONV_DOUBLE_IN_INPUT_SIZE = 3;
|
||||
const int8_t CONV_QUART_IN_INPUT_SIZE = 5;
|
||||
const int8_t ELTWISE_USE = 1;
|
||||
|
|
Loading…
Reference in New Issue