!33914 add BatchMatmulEltwiseFusionPass for buffer fusion

Merge pull request !33914 from yuchaojie/ub_fusion2
This commit is contained in:
i-robot 2022-05-11 09:47:11 +00:00 committed by Gitee
commit 36616efc7f
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 160 additions and 73 deletions

View File

@ -144,6 +144,7 @@ constexpr auto kSqrtOpName = "Sqrt";
constexpr auto kRsqrtOpName = "Rsqrt";
constexpr auto kRsqrtGradOpName = "RsqrtGrad";
constexpr auto kErfOpName = "Erf";
constexpr auto kDivOpName = "Div";
constexpr auto kRealDivOpName = "RealDiv";
constexpr auto kLambUpdateWithLROpName = "LambUpdateWithLR";
constexpr auto kLambNextMVWithDecayOpName = "LambNextMVWithDecay";
@ -177,6 +178,8 @@ constexpr auto kSendOpName = "StreamSend";
constexpr auto kRecvOpName = "StreamRecv";
constexpr auto kRpcSendOpName = "RpcSend";
constexpr auto kRpcRecvOpName = "RpcRecv";
constexpr auto kReluOpName = "ReLU";
constexpr auto kReluGradOpName = "ReluGrad";
constexpr auto kReluV2OpName = "ReLUV2";
constexpr auto kReluGradV2OpName = "ReluGradV2";
constexpr auto kAddOpName = "Add";
@ -335,6 +338,7 @@ constexpr auto kSubAndFilterOpName = "SubAndFilter";
constexpr auto kPadAndShiftOpName = "PadAndShift";
constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits";
constexpr auto kOneHotOpName = "OneHot";
constexpr auto kSigmoidOpName = "Sigmoid";
constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits";
constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler";
constexpr auto kLogSoftmaxGradOpName = "LogSoftmaxGrad";

View File

@ -44,7 +44,9 @@ static const std::map<std::string, OptPassEnum> kPassCodeMap = {
{std::to_string(39), OptPassEnum::ClipByNormNoDivSquareSumFusion},
{std::to_string(42), OptPassEnum::MulAddNPass},
{std::to_string(43), OptPassEnum::Resnet50DbnDwFusionPass},
{std::to_string(44), OptPassEnum::BatchMatMulDropOutDoMaskV3DFusionPass},
{std::to_string(45), OptPassEnum::MatmulConfusiontransposeUbFusion},
{std::to_string(46), OptPassEnum::MatMulDropOutDoMaskV3DFusionPass},
{std::to_string(47), OptPassEnum::TbeBatchMatmulElementWiseFusionPass},
};

View File

@ -43,7 +43,9 @@ enum class OptPassEnum {
ConvBnReduceFusionPass,
MulAddNPass,
Resnet50DbnDwFusionPass,
BatchMatMulDropOutDoMaskV3DFusionPass,
MatmulConfusiontransposeUbFusion,
MatMulDropOutDoMaskV3DFusionPass,
TbeBatchMatmulElementWiseFusionPass,
Invalid,
};

View File

@ -110,9 +110,9 @@
#include "plugin/device/ascend/optimizer/buffer_fusion/matmul_eltwise_fusion_pass.h"
#include "plugin/device/ascend/optimizer/buffer_fusion/matmul_confusiontranspose_fusion_pass.h"
#include "plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.h"
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h"
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.h"
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_reducesum_fusion_pass.h"
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.h"
#include "plugin/device/ascend/optimizer/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h"
#include "plugin/device/ascend/optimizer/buffer_fusion/bnupdate_eltwise_fusion_pass.h"
#include "plugin/device/ascend/optimizer/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h"
@ -564,7 +564,7 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator));
if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulFusedMulAddFusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulEltwiseFusionPass>(fusion_id_allocator));
}
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulDropoutDoMaskV3FusionPass>(fusion_id_allocator));
ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>());

View File

@ -32,7 +32,9 @@ namespace opt {
class BatchMatmulDropoutDoMaskV3FusionPass : public FusionBasePass {
public:
explicit BatchMatmulDropoutDoMaskV3FusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("BatchMatmulDropoutDoMaskV3FusionPass", std::move(idAllocator)) {}
: FusionBasePass("BatchMatmulDropoutDoMaskV3FusionPass", std::move(idAllocator)) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatMulDropOutDoMaskV3DFusionPass);
}
~BatchMatmulDropoutDoMaskV3FusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -0,0 +1,132 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.h"
#include <set>
#include <string>
#include "kernel/kernel_fusion.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "mindspore/core/ops/core_ops.h"
#include "common/graph_kernel/graph_kernel_flags.h"
#include "backend/common/optimizer/fusion_id_allocator.h"
namespace mindspore {
namespace opt {
namespace {
// Node attribute used to veto buffer fusion for a particular node.
constexpr auto kAttrNoFusion = "no_fusion";
// Scans the data inputs of `cnode` (1-based; input 0 is the primitive) and
// returns the first input that is a CNode whose op name equals `node_type`
// and whose TBE fusion type equals `fusion_type`; nullptr when none matches.
CNodePtr FindInputNode(const CNodePtr &cnode, const string &node_type, const kernel::FusionType &fusion_type) {
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 1; i <= input_num; ++i) {
auto input = cnode->input(i);
if (input->isa<CNode>() && common::AnfAlgo::GetCNodeName(input) == node_type &&
AnfAlgo::GetFusionType(input) == fusion_type) {
return input->cast<CNodePtr>();
}
}
return nullptr;
}
}  // namespace
// Pattern 1: BatchMatMul -> eltwise -> eltwise1.
// `eltwise1` must be one of {Add, ReLU, FusedMulAdd}, and at least one of its
// CNode inputs must itself satisfy Pattern 2 (an elementwise op fed directly
// by a fusible BatchMatMul).  On success every matched node, including
// `eltwise1`, is inserted into `record`; returns whether the pattern matched.
bool BatchMatmulEltwiseFusionPass::MatchPattern1(const CNodePtr &eltwise1, mindspore::HashSet<AnfNodePtr> *record) {
// bmm - eltwise - eltwise1
const std::set<string> kElem1TypeList = {kAddOpName, kReluOpName, kFusedMulAddOpName};
if (kElem1TypeList.find(common::AnfAlgo::GetCNodeName(eltwise1)) == kElem1TypeList.end()) {
return false;
}
auto input_num = common::AnfAlgo::GetInputTensorNum(eltwise1);
for (size_t i = 1; i <= input_num; ++i) {
auto eltwise1_input = eltwise1->input(i);
// MatchPattern2 fills `record` with the inner eltwise and its BatchMatMul
// before we add eltwise1 itself.
if (eltwise1_input->isa<CNode>() && MatchPattern2(eltwise1_input->cast<CNodePtr>(), record)) {
record->insert(eltwise1);
return true;
}
}
return false;
}
// Pattern 2: BatchMatMul -> eltwise.
// `eltwise` must be one of the elementwise ops listed below and must take a
// BatchMatMul input with fusion type BATCH_MATMUL.  The BatchMatMul is
// rejected when it is dynamic-shape or carries the "no_fusion" attribute.
// On success both nodes are inserted into `record`.
bool BatchMatmulEltwiseFusionPass::MatchPattern2(const CNodePtr &eltwise, mindspore::HashSet<AnfNodePtr> *record) {
// bmm - eltwise
const std::set<string> kElemTypeList = {kFusedMulAddOpName, kAddOpName, kDivOpName,
kRealDivOpName, kReluOpName, kReluGradOpName};
if (kElemTypeList.find(common::AnfAlgo::GetCNodeName(eltwise)) == kElemTypeList.end()) {
return false;
}
CNodePtr bmm = FindInputNode(eltwise, kBatchMatMulOpName, kernel::FusionType::BATCH_MATMUL);
// NOTE(review): assumes GetBooleanAttr yields false (rather than raising)
// when the "no_fusion" attribute is absent — confirm against AnfAlgo.
if (bmm == nullptr || common::AnfAlgo::IsDynamicShape(bmm) || common::AnfAlgo::GetBooleanAttr(bmm, kAttrNoFusion)) {
return false;
}
record->insert(eltwise);
record->insert(bmm);
return true;
}
bool BatchMatmulEltwiseFusionPass::MatchPattern3(const CNodePtr &eltwise, mindspore::HashSet<AnfNodePtr> *record) {
// bmm - eltwise1(mul) - eltwise2(sigmoid) - eltwise(mul)
if (common::AnfAlgo::GetCNodeName(eltwise) != kMulOpName) {
return false;
}
CNodePtr eltwise2 = FindInputNode(eltwise, kSigmoidOpName, kernel::FusionType::ELEMWISE);
if (eltwise2 == nullptr) {
return false;
}
CNodePtr eltwise1 = FindInputNode(eltwise2, kMulOpName, kernel::FusionType::ELEMWISE);
if (eltwise1 == nullptr) {
return false;
}
CNodePtr bmm = FindInputNode(eltwise1, kBatchMatMulOpName, kernel::FusionType::BATCH_MATMUL);
if (bmm == nullptr || common::AnfAlgo::IsDynamicShape(bmm)) {
return false;
}
record->insert(eltwise);
record->insert(eltwise2);
record->insert(eltwise1);
record->insert(bmm);
return true;
}
// Entry point of the pass: walks the kernel graph in topological order and,
// for every real TBE elementwise kernel that has no fusion id yet, tries
// patterns 1-3 in order; the first match is pushed as a candidate fusion
// group and its nodes are tagged with a fusion id.
void BatchMatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
const auto &node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
// NOTE(review): unlike the predecessor pass, `node` is not null-checked
// (no MS_EXCEPTION_IF_NULL) before use here — confirm TopoSort guarantees.
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// Only TBE elementwise kernels can be the tail node of any of the patterns.
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
mindspore::HashSet<AnfNodePtr> record;
// Short-circuit: each MatchPatternN only mutates `record` on success.
if (MatchPattern1(cnode, &record) || MatchPattern2(cnode, &record) || MatchPattern3(cnode, &record)) {
candidate_fusion->push_back(record);
SetRecordFusionId(record);
}
}
}
}
} // namespace opt
} // namespace mindspore

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_FUSION_PASS_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_FUSION_PASS_H_
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_ELTWISE_FUSION_PASS_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_ELTWISE_FUSION_PASS_H_
#include "utils/hash_set.h"
#include "plugin/device/ascend/optimizer/buffer_fusion/fusion_base_pass.h"
@ -27,20 +27,21 @@
namespace mindspore {
namespace opt {
class BatchMatmulFusedMulAddFusionPass : public FusionBasePass {
class BatchMatmulEltwiseFusionPass : public FusionBasePass {
public:
explicit BatchMatmulFusedMulAddFusionPass(const FusionIdAllocatorPtr &idAllocator)
: FusionBasePass("BatchMatmulFusedMulAddFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatmulFusedMulAddFusionPass);
explicit BatchMatmulEltwiseFusionPass(const FusionIdAllocatorPtr &idAllocator)
: FusionBasePass("BatchMatmulEltwiseFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::TbeBatchMatmulElementWiseFusionPass);
}
~BatchMatmulFusedMulAddFusionPass() override = default;
~BatchMatmulEltwiseFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
private:
void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion);
bool MatchPattern1(const CNodePtr &eltwise1, mindspore::HashSet<AnfNodePtr> *record);
bool MatchPattern2(const CNodePtr &eltwise, mindspore::HashSet<AnfNodePtr> *record);
bool MatchPattern3(const CNodePtr &eltwise, mindspore::HashSet<AnfNodePtr> *record);
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_FUSION_PASS_H_
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_ELTWISE_FUSION_PASS_H_

View File

@ -1,58 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h"
#include "kernel/kernel_fusion.h"
#include "backend/common/session/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "mindspore/core/ops/core_ops.h"
#include "utils/ms_context.h"
#include "backend/common/optimizer/fusion_id_allocator.h"
namespace mindspore {
namespace opt {
// (File deleted by this commit; superseded by BatchMatmulEltwiseFusionPass.)
// Records a {FusedMulAdd, BatchMatMul} pair as a fusion candidate when input
// kIndex2 of the FusedMulAdd node is a BatchMatMul CNode.
void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
// kIndex2: presumably the BatchMatMul output operand of FusedMulAdd —
// confirm against the op's input layout.
auto batch_matmul = cnode->input(kIndex2);
MS_EXCEPTION_IF_NULL(batch_matmul);
if (batch_matmul->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(batch_matmul, prim::kPrimBatchMatMul)) {
mindspore::HashSet<AnfNodePtr> record{cnode, batch_matmul};
candidate_fusion->push_back(record);
SetRecordFusionId(record);
}
}
// (File deleted by this commit; superseded by BatchMatmulEltwiseFusionPass.)
// Walks the kernel graph in topological order and tries the single
// FusedMulAdd+BatchMatMul pattern on every real FusedMulAdd kernel that does
// not yet carry a fusion id.
void BatchMatmulFusedMulAddFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
const auto &node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
// Only FusedMulAdd nodes can be the root of this pattern.
if (common::AnfAlgo::GetCNodeName(cnode) == kFusedMulAddOpName) {
MatchBatchMatmulFusedMulAdd(cnode, kernel_graph, candidate_fusion);
}
}
}
} // namespace opt
} // namespace mindspore

View File

@ -30,7 +30,9 @@ namespace opt {
class MatmulDropoutDoMaskV3AddFusionPass : public FusionBasePass {
public:
explicit MatmulDropoutDoMaskV3AddFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) {}
: FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MatMulDropOutDoMaskV3DFusionPass);
}
~MatmulDropoutDoMaskV3AddFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;