fix bnupdate_eltwise_eltwise_fusion_pass
This commit is contained in:
parent
3a195af6c0
commit
fdbfdc56a5
|
@ -28,20 +28,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr size_t kEltwiseInputSize = 2;
|
||||
constexpr size_t kEltwiseOutputSize = 2;
|
||||
bool CheckEltwiseInputAndOutputSize(const AnfNodePtr &node) {
|
||||
if (AnfAlgo::GetInputTensorNum(node) == kEltwiseInputSize) {
|
||||
return true;
|
||||
}
|
||||
if (AnfAlgo::GetOutputTensorNum(node) == kEltwiseOutputSize) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
||||
const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
|
@ -82,7 +68,8 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) {
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
|
||||
AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) {
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimAdd)) {
|
||||
|
|
Loading…
Reference in New Issue