Fix bnupdate_eltwise_eltwise ub fusion pass
This commit is contained in:
parent
dc8dbeb2b8
commit
c4bbf5a282
|
@ -27,6 +27,20 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace opt {
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
constexpr size_t kEltwiseInputSize = 2;
|
||||||
|
constexpr size_t kEltwiseOutputSize = 2;
|
||||||
|
bool CheckEltwiseInputAndOutputSize(const AnfNodePtr &node) {
|
||||||
|
if (AnfAlgo::GetInputTensorNum(node) == kEltwiseInputSize) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (AnfAlgo::GetOutputTensorNum(node) == kEltwiseOutputSize) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
||||||
const session::KernelGraph &kernel_graph,
|
const session::KernelGraph &kernel_graph,
|
||||||
FusedNodeRecord *candidate_fusion) {
|
FusedNodeRecord *candidate_fusion) {
|
||||||
|
@ -74,8 +88,9 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
|
||||||
auto cnode = node->cast<CNodePtr>();
|
auto cnode = node->cast<CNodePtr>();
|
||||||
MS_EXCEPTION_IF_NULL(cnode);
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
|
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && CheckEltwiseInputAndOutputSize(cnode)) {
|
||||||
auto eltwise_input = cnode->input(1);
|
auto eltwise_input = cnode->input(1);
|
||||||
|
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||||
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
|
if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTensorAdd)) {
|
||||||
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue