fix bnupdate_eltwise_eltwise_fusion_pass

This commit is contained in:
yuchaojie 2021-06-26 11:37:47 +08:00
parent 3a195af6c0
commit fdbfdc56a5
1 changed files with 2 additions and 15 deletions

View File

@ -28,20 +28,6 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace {
constexpr size_t kEltwiseInputSize = 2;
constexpr size_t kEltwiseOutputSize = 2;
bool CheckEltwiseInputAndOutputSize(const AnfNodePtr &node) {
if (AnfAlgo::GetInputTensorNum(node) == kEltwiseInputSize) {
return true;
}
if (AnfAlgo::GetOutputTensorNum(node) == kEltwiseOutputSize) {
return true;
}
return false;
}
} // namespace
void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
const session::KernelGraph &kernel_graph, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) { FusedNodeRecord *candidate_fusion) {
@ -82,7 +68,8 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) { AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) {
auto eltwise_input = cnode->input(kIndex1); auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input); MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) { if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) {