!47937 fix conv2dbackprop eltwise passes
Merge pull request !47937 from xulei/convback_eltwise_pass
This commit is contained in:
commit
2d968044fc
|
@ -32,8 +32,9 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
|
|||
mindspore::HashSet<AnfNodePtr> record{cnode};
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input) &&
|
||||
common::AnfAlgo::GetCNodeName(eltwise_input) == kAddNOpName) {
|
||||
const std::unordered_set<std::string> support_node_names{kAddNOpName, kAddOpName};
|
||||
if (CheckDoubleInEltWiseNode(kernel_graph, eltwise_input, {kernel::kPatternElemWise, kernel::kPatternBroadcast}) &&
|
||||
support_node_names.find(common::AnfAlgo::GetCNodeName(eltwise_input)) != support_node_names.cend()) {
|
||||
(void)record.insert(eltwise_input);
|
||||
} else {
|
||||
return;
|
||||
|
@ -42,8 +43,8 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
|
|||
MS_EXCEPTION_IF_NULL(manager);
|
||||
auto input_cnode = eltwise_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
std::vector candidate_cb_node{input_cnode->input(kIndex2), input_cnode->input(kIndex1)};
|
||||
for (const auto &cb_node : candidate_cb_node) {
|
||||
std::vector candidate_cb_nodes{input_cnode->input(kIndex2), input_cnode->input(kIndex1)};
|
||||
for (const auto &cb_node : candidate_cb_nodes) {
|
||||
MS_EXCEPTION_IF_NULL(cb_node);
|
||||
if (!cb_node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(cb_node)) {
|
||||
return;
|
||||
|
|
|
@ -38,6 +38,13 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
|
|||
return;
|
||||
}
|
||||
if (common::AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInputD)) {
|
||||
// if cnode is ReluGradV2, we need to do a further check
|
||||
// skip fusion when output0 of Conv2DBackpropInputD is fp32, because it may be slower
|
||||
const std::unordered_set<TypeId> fp32_types{TypeId::kNumberTypeFloat32, TypeId::kNumberTypeFloat};
|
||||
if (common::AnfAlgo::GetCNodeName(cnode) == kReluGradV2OpName &&
|
||||
fp32_types.count(AnfAlgo::GetOutputDeviceDataType(eltwise_input, kIndex0)) > 0) {
|
||||
return;
|
||||
}
|
||||
(void)record.insert(eltwise_input);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
@ -57,8 +64,11 @@ void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::Ke
|
|||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::set<std::string> support_type = {kReluOpName, kPReluOpName, kLeakyReluOpName};
|
||||
if (support_type.count(common::AnfAlgo::GetCNodeName(cnode)) > 0) {
|
||||
std::set<std::string> support_node_names = {kReluGradV2OpName, kReluOpName, kPReluOpName, kLeakyReluOpName,
|
||||
kAddOpName};
|
||||
std::set<std::string> support_fusion_types = {kernel::kPatternElemWise, kernel::kPatternBroadcast};
|
||||
if (support_node_names.count(common::AnfAlgo::GetCNodeName(cnode)) > 0 &&
|
||||
support_fusion_types.count(AnfAlgo::GetFusionType(cnode)) > 0) {
|
||||
MatchConv2DBackpropInputEltwise(cnode, candidate_fusion);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,7 +32,7 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
|
|||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
mindspore::HashSet<AnfNodePtr> record{cnode};
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
while (CheckSingleInEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
(void)record.insert(eltwise_input);
|
||||
auto input_cnode = eltwise_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
|
|
|
@ -22,49 +22,33 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
bool FusionBasePass::CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
bool FusionBasePass::CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
|
||||
const std::unordered_set<std::string> &fusion_types, size_t input_size,
|
||||
size_t not_updatestate_size) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
|
||||
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_USE &&
|
||||
cnode->inputs().size() == ELTWISE_INPUT_SIZE;
|
||||
fusion_types.find(AnfAlgo::GetFusionType(node)) != fusion_types.cend() &&
|
||||
GetNotUpdateStateUserNums(kernel_graph, node) == not_updatestate_size && cnode->inputs().size() == input_size;
|
||||
}
|
||||
|
||||
bool FusionBasePass::CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
|
||||
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_USE &&
|
||||
cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
|
||||
bool FusionBasePass::CheckSingleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
|
||||
const std::unordered_set<std::string> &fusion_types) {
|
||||
return CheckEltWiseNode(kernel_graph, node, fusion_types, ELTWISE_INPUT_SIZE, ELTWISE_USE);
|
||||
}
|
||||
|
||||
bool FusionBasePass::CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
|
||||
auto manager = kernel_graph.manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (!node->isa<CNode>() || !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) {
|
||||
return false;
|
||||
}
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
|
||||
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_MULTI_USE &&
|
||||
cnode->inputs().size() == ELTWISE_INPUT_SIZE;
|
||||
bool FusionBasePass::CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
|
||||
const std::unordered_set<std::string> &fusion_types) {
|
||||
return CheckEltWiseNode(kernel_graph, node, fusion_types, ELTWISE_DOUBLE_IN_INPUT_SIZE, ELTWISE_USE);
|
||||
}
|
||||
|
||||
bool FusionBasePass::CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
|
||||
const std::unordered_set<std::string> &fusion_types) {
|
||||
return CheckEltWiseNode(kernel_graph, node, fusion_types, ELTWISE_INPUT_SIZE, ELTWISE_MULTI_USE);
|
||||
}
|
||||
|
||||
size_t FusionBasePass::GetNotUpdateStateUserNums(const session::KernelGraph &kernel_graph,
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <unordered_set>
|
||||
#include "utils/hash_map.h"
|
||||
#include "utils/hash_set.h"
|
||||
#include "ir/anf.h"
|
||||
|
@ -72,9 +73,28 @@ class FusionBasePass : public PassWithSwitch {
|
|||
protected:
|
||||
bool RunPass(const FuncGraphPtr &graph) override;
|
||||
void SetRecordFusionId(const mindspore::HashSet<AnfNodePtr> &record);
|
||||
bool CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
|
||||
bool CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
|
||||
bool CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node);
|
||||
bool CheckEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
|
||||
const std::unordered_set<std::string> &fusion_types, size_t input_size,
|
||||
size_t not_updatestate_size);
|
||||
|
||||
bool CheckSingleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
|
||||
return CheckSingleInEltWiseNode(kernel_graph, node, {kernel::kPatternElemWise});
|
||||
}
|
||||
bool CheckSingleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
|
||||
const std::unordered_set<std::string> &fusion_types);
|
||||
|
||||
bool CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
|
||||
return CheckDoubleInEltWiseNode(kernel_graph, node, {kernel::kPatternElemWise});
|
||||
}
|
||||
bool CheckDoubleInEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
|
||||
const std::unordered_set<std::string> &fusion_types);
|
||||
|
||||
bool CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) {
|
||||
return CheckMultiOutputEltWiseNode(kernel_graph, node, {kernel::kPatternElemWise});
|
||||
}
|
||||
bool CheckMultiOutputEltWiseNode(const session::KernelGraph &kernel_graph, const AnfNodePtr &node,
|
||||
const std::unordered_set<std::string> &fusion_types);
|
||||
|
||||
size_t GetNotUpdateStateUserNums(const session::KernelGraph &kernel_graph, const AnfNodePtr &node) const;
|
||||
FusionIdAllocatorPtr fusion_id_allocator;
|
||||
};
|
||||
|
|
|
@ -39,7 +39,7 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
|
|||
} else {
|
||||
return;
|
||||
}
|
||||
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
while (CheckSingleInEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
(void)record.insert(eltwise_input);
|
||||
if (record.size() == MULTI_ELTWISE_SIZE) {
|
||||
break;
|
||||
|
|
|
@ -32,7 +32,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
|
|||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
mindspore::HashSet<AnfNodePtr> record{cnode};
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
while (CheckSingleInEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
(void)record.insert(eltwise_input);
|
||||
auto input_cnode = eltwise_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
|
@ -54,7 +54,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
|
|||
MS_EXCEPTION_IF_NULL(previous_input_cnode);
|
||||
auto previous_eltwise_input = previous_input_cnode->input(kIndex1);
|
||||
auto previous_size = record.size();
|
||||
while (CheckEltWiseNode(kernel_graph, previous_eltwise_input)) {
|
||||
while (CheckSingleInEltWiseNode(kernel_graph, previous_eltwise_input)) {
|
||||
(void)record.insert(previous_eltwise_input);
|
||||
auto previous_node = previous_eltwise_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(previous_node);
|
||||
|
|
|
@ -30,7 +30,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
|
|||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
mindspore::HashSet<AnfNodePtr> record{cnode};
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
while (CheckEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
while (CheckSingleInEltWiseNode(kernel_graph, eltwise_input)) {
|
||||
(void)record.insert(eltwise_input);
|
||||
auto input_cnode = eltwise_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
|
@ -51,7 +51,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
|
|||
MS_EXCEPTION_IF_NULL(previous_input_cnode);
|
||||
auto previous_eltwise_input = previous_input_cnode->input(kIndex1);
|
||||
auto previous_size = record.size();
|
||||
while (CheckEltWiseNode(kernel_graph, previous_eltwise_input)) {
|
||||
while (CheckSingleInEltWiseNode(kernel_graph, previous_eltwise_input)) {
|
||||
(void)record.insert(previous_eltwise_input);
|
||||
MS_EXCEPTION_IF_NULL(previous_eltwise_input);
|
||||
auto previous_node = previous_eltwise_input->cast<CNodePtr>();
|
||||
|
|
|
@ -30,7 +30,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
|
|||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
mindspore::HashSet<AnfNodePtr> record{cnode};
|
||||
auto write_input = cnode->input(kIndex1);
|
||||
if (CheckEltWiseNode(kernel_graph, write_input)) {
|
||||
if (CheckSingleInEltWiseNode(kernel_graph, write_input)) {
|
||||
(void)record.insert(write_input);
|
||||
auto input_cnode = write_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(input_cnode);
|
||||
|
|
Loading…
Reference in New Issue