!47937 fix conv2dbackprop eltwise passes

Merge pull request !47937 from xulei/convback_eltwise_pass
This commit is contained in:
i-robot 2023-01-17 07:25:08 +00:00 committed by Gitee
commit 2d968044fc
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 63 additions and 48 deletions

View File

@ -32,8 +32,9 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
mindspore::HashSet<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input) &&
common::AnfAlgo::GetCNodeName(eltwise_input) == kAddNOpName) {
const std::unordered_set<std::string> support_node_names{kAddNOpName, kAddOpName};
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input, {kernel::kPatternElemWise, kernel::kPatternBroadcast}) &&
support_node_names.find(common::AnfAlgo::GetCNodeName(eltwise_input)) != support_node_names.cend()) {
(void)record.insert(eltwise_input);
} else {
return;
@ -42,8 +43,8 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
MS_EXCEPTION_IF_NULL(manager);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
std::vector candidate_cb_node{input_cnode->input(kIndex2), input_cnode->input(kIndex1)};
for (const auto &cb_node : candidate_cb_node) {
std::vector candidate_cb_nodes{input_cnode->input(kIndex2), input_cnode->input(kIndex1)};
for (const auto &cb_node : candidate_cb_nodes) {
MS_EXCEPTION_IF_NULL(cb_node);
if (!cb_node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(cb_node)) {
return;

View File

@ -38,6 +38,13 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
return;
}
if (common::AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInputD)) {
// if cnode is ReluGradV2, we need do further check
// skip when output0 of Conv2DBackpropInputD is fp32, it may be slower
const std::unordered_set<TypeId> fp32_types{TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat};
if (common::AnfAlgo::GetCNodeName(cnode) == kReluGradV2OpName &&
fp32_types.count(AnfAlgo::GetOutputDeviceDataType(eltwise_input, kIndex0)) > 0) {
return;
}
(void)record.insert(eltwise_input);
candidate_fusion->push_back(record);
SetRecordFusionId(record);
@ -57,8 +64,11 @@ void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::Ke
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::set<std::string> support_type = {kReluOpName, kPReluOpName, kLeakyReluOpName};
if (support_type.count(common::AnfAlgo::GetCNodeName(cnode)) > 0) {
std::set<std::string> support_node_names = {kReluGradV2OpName, kReluOpName, kPReluOpName, kLeakyReluOpName,
kAddOpName};
std::set<std::string> support_fusion_types = {kernel::kPatternElemWise, kernel::kPatternBroadcast};
if (support_node_names.count(common::AnfAlgo::GetCNodeName(cnode)) > 0 &&
support_fusion_types.count(AnfAlgo::GetFusionType(cnode)) > 0) {
MatchConv2DBackpropInputEltwise(cnode, candidate_fusion);
}
}

View File

@ -32,7 +32,7 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
MS_EXCEPTION_IF_NULL(candidate_fusion);
mindspore::HashSet<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(kIndex1);
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
while (CheckSingleInEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);

View File

@ -22,49 +22,33 @@
namespace mindspore {
namespace opt {
bool FusionBasePass::CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
bool FusionBasePass::CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
const std::unordered_set<std::string> &fusion_types, size_t input_size,
size_t not_updatestate_size) {
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_USE &&
cnode->inputs().size() == ELTWISE_INPUT_SIZE;
fusion_types.find(AnfAlgo::GetFusionType(node)) != fusion_types.cend() &&
GetNotUpdateStateUserNums(kernel_graph, node) == not_updatestate_size && cnode->inputs().size() == input_size;
}
bool FusionBasePass::CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_USE &&
cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
bool FusionBasePass::CheckSingleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
const std::unordered_set<std::string> &fusion_types) {
return CheckEltWiseNode(kernel_graph, node, fusion_types, ELTWISE_INPUT_SIZE, ELTWISE_USE);
}
bool FusionBasePass::CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
MS_EXCEPTION_IF_NULL(node);
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
return false;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_MULTI_USE &&
cnode->inputs().size() == ELTWISE_INPUT_SIZE;
bool FusionBasePass::CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
const std::unordered_set<std::string> &fusion_types) {
return CheckEltWiseNode(kernel_graph, node, fusion_types, ELTWISE_DOUBLE_IN_INPUT_SIZE, ELTWISE_USE);
}
bool FusionBasePass::CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
const std::unordered_set<std::string> &fusion_types) {
return CheckEltWiseNode(kernel_graph, node, fusion_types, ELTWISE_INPUT_SIZE, ELTWISE_MULTI_USE);
}
size_t FusionBasePass::GetNotUpdateStateUserNums(const session::KernelGraph &kernel_graph,

View File

@ -18,6 +18,7 @@
#include <vector>
#include <string>
#include <utility>
#include <unordered_set>
#include "utils/hash_map.h"
#include "utils/hash_set.h"
#include "ir/anf.h"
@ -72,9 +73,28 @@ class FusionBasePass : public PassWithSwitch {
protected:
bool RunPass(const FuncGraphPtr &graph) override;
void SetRecordFusionId(const mindspore::HashSet<AnfNodePtr> &record);
bool CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
bool CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
bool CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
bool CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
const std::unordered_set<std::string> &fusion_types, size_t input_size,
size_t not_updatestate_size);
bool CheckSingleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
return CheckSingleInEltWiseNode(kernel_graph, node, {kernel::kPatternElemWise});
}
bool CheckSingleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
const std::unordered_set<std::string> &fusion_types);
bool CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
return CheckDoubleInEltWiseNode(kernel_graph, node, {kernel::kPatternElemWise});
}
bool CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
const std::unordered_set<std::string> &fusion_types);
bool CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
return CheckMultiOutputEltWiseNode(kernel_graph, node, {kernel::kPatternElemWise});
}
bool CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
const std::unordered_set<std::string> &fusion_types);
size_t GetNotUpdateStateUserNums(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) const;
FusionIdAllocatorPtr fusion_id_allocator;
};

View File

@ -39,7 +39,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
} else {
return;
}
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
while (CheckSingleInEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
if (record.size() == MULTI_ELTWISE_SIZE) {
break;

View File

@ -32,7 +32,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
MS_EXCEPTION_IF_NULL(candidate_fusion);
mindspore::HashSet<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(kIndex1);
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
while (CheckSingleInEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
@ -54,7 +54,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
MS_EXCEPTION_IF_NULL(previous_input_cnode);
auto previous_eltwise_input = previous_input_cnode->input(kIndex1);
auto previous_size = record.size();
while (CheckEltWiseNode(kernel_graph, previous_eltwise_input)) {
while (CheckSingleInEltWiseNode(kernel_graph, previous_eltwise_input)) {
(void)record.insert(previous_eltwise_input);
auto previous_node = previous_eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(previous_node);

View File

@ -30,7 +30,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
MS_EXCEPTION_IF_NULL(candidate_fusion);
mindspore::HashSet<AnfNodePtr> record{cnode};
auto eltwise_input = cnode->input(kIndex1);
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
while (CheckSingleInEltWiseNode(kernel_graph, eltwise_input)) {
(void)record.insert(eltwise_input);
auto input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);
@ -51,7 +51,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
MS_EXCEPTION_IF_NULL(previous_input_cnode);
auto previous_eltwise_input = previous_input_cnode->input(kIndex1);
auto previous_size = record.size();
while (CheckEltWiseNode(kernel_graph, previous_eltwise_input)) {
while (CheckSingleInEltWiseNode(kernel_graph, previous_eltwise_input)) {
(void)record.insert(previous_eltwise_input);
MS_EXCEPTION_IF_NULL(previous_eltwise_input);
auto previous_node = previous_eltwise_input->cast<CNodePtr>();

View File

@ -30,7 +30,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
MS_EXCEPTION_IF_NULL(candidate_fusion);
mindspore::HashSet<AnfNodePtr> record{cnode};
auto write_input = cnode->input(kIndex1);
if (CheckEltWiseNode(kernel_graph, write_input)) {
if (CheckSingleInEltWiseNode(kernel_graph, write_input)) {
(void)record.insert(write_input);
auto input_cnode = write_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(input_cnode);