forked from mindspore-Ecosystem/mindspore
!33914 add BatchMatmulEltwiseFusionPass for buffer fusion
Merge pull request !33914 from yuchaojie/ub_fusion2
This commit is contained in:
commit
36616efc7f
|
@ -144,6 +144,7 @@ constexpr auto kSqrtOpName = "Sqrt";
|
|||
constexpr auto kRsqrtOpName = "Rsqrt";
|
||||
constexpr auto kRsqrtGradOpName = "RsqrtGrad";
|
||||
constexpr auto kErfOpName = "Erf";
|
||||
constexpr auto kDivOpName = "Div";
|
||||
constexpr auto kRealDivOpName = "RealDiv";
|
||||
constexpr auto kLambUpdateWithLROpName = "LambUpdateWithLR";
|
||||
constexpr auto kLambNextMVWithDecayOpName = "LambNextMVWithDecay";
|
||||
|
@ -177,6 +178,8 @@ constexpr auto kSendOpName = "StreamSend";
|
|||
constexpr auto kRecvOpName = "StreamRecv";
|
||||
constexpr auto kRpcSendOpName = "RpcSend";
|
||||
constexpr auto kRpcRecvOpName = "RpcRecv";
|
||||
constexpr auto kReluOpName = "ReLU";
|
||||
constexpr auto kReluGradOpName = "ReluGrad";
|
||||
constexpr auto kReluV2OpName = "ReLUV2";
|
||||
constexpr auto kReluGradV2OpName = "ReluGradV2";
|
||||
constexpr auto kAddOpName = "Add";
|
||||
|
@ -335,6 +338,7 @@ constexpr auto kSubAndFilterOpName = "SubAndFilter";
|
|||
constexpr auto kPadAndShiftOpName = "PadAndShift";
|
||||
constexpr auto kSparseSoftmaxCrossEntropyWithLogitsOpName = "SparseSoftmaxCrossEntropyWithLogits";
|
||||
constexpr auto kOneHotOpName = "OneHot";
|
||||
constexpr auto kSigmoidOpName = "Sigmoid";
|
||||
constexpr auto kSoftmaxCrossEntropyWithLogitsOpName = "SoftmaxCrossEntropyWithLogits";
|
||||
constexpr auto kUniformCandidateSamplerOpName = "UniformCandidateSampler";
|
||||
constexpr auto kLogSoftmaxGradOpName = "LogSoftmaxGrad";
|
||||
|
|
|
@ -44,7 +44,9 @@ static const std::map<std::string, OptPassEnum> kPassCodeMap = {
|
|||
{std::to_string(39), OptPassEnum::ClipByNormNoDivSquareSumFusion},
|
||||
{std::to_string(42), OptPassEnum::MulAddNPass},
|
||||
{std::to_string(43), OptPassEnum::Resnet50DbnDwFusionPass},
|
||||
{std::to_string(44), OptPassEnum::BatchMatMulDropOutDoMaskV3DFusionPass},
|
||||
{std::to_string(45), OptPassEnum::MatmulConfusiontransposeUbFusion},
|
||||
{std::to_string(46), OptPassEnum::MatMulDropOutDoMaskV3DFusionPass},
|
||||
{std::to_string(47), OptPassEnum::TbeBatchMatmulElementWiseFusionPass},
|
||||
};
|
||||
|
||||
|
|
|
@ -43,7 +43,9 @@ enum class OptPassEnum {
|
|||
ConvBnReduceFusionPass,
|
||||
MulAddNPass,
|
||||
Resnet50DbnDwFusionPass,
|
||||
BatchMatMulDropOutDoMaskV3DFusionPass,
|
||||
MatmulConfusiontransposeUbFusion,
|
||||
MatMulDropOutDoMaskV3DFusionPass,
|
||||
TbeBatchMatmulElementWiseFusionPass,
|
||||
Invalid,
|
||||
};
|
||||
|
|
|
@ -110,9 +110,9 @@
|
|||
#include "plugin/device/ascend/optimizer/buffer_fusion/matmul_eltwise_fusion_pass.h"
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/matmul_confusiontranspose_fusion_pass.h"
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/matmul_dropoutdomaskv3_add_fusion_pass.h"
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h"
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_dropoutdomaskv3_fusion_pass.h"
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_reducesum_fusion_pass.h"
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.h"
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h"
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/bnupdate_eltwise_fusion_pass.h"
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h"
|
||||
|
@ -564,7 +564,7 @@ void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGrap
|
|||
ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<MatmulConfusionTranposeFusionPass>(fusion_id_allocator));
|
||||
if (!graphkernel::GraphKernelFlags::GetInstance().IsEnableGraphKernel()) {
|
||||
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulFusedMulAddFusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulEltwiseFusionPass>(fusion_id_allocator));
|
||||
}
|
||||
ub_fusion_pm->AddPass(std::make_shared<BatchMatmulDropoutDoMaskV3FusionPass>(fusion_id_allocator));
|
||||
ub_fusion_pm->AddPass(std::make_shared<UbPatternFusion>());
|
||||
|
|
|
@ -32,7 +32,9 @@ namespace opt {
|
|||
class BatchMatmulDropoutDoMaskV3FusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit BatchMatmulDropoutDoMaskV3FusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("BatchMatmulDropoutDoMaskV3FusionPass", std::move(idAllocator)) {}
|
||||
: FusionBasePass("BatchMatmulDropoutDoMaskV3FusionPass", std::move(idAllocator)) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatMulDropOutDoMaskV3DFusionPass);
|
||||
}
|
||||
~BatchMatmulDropoutDoMaskV3FusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -0,0 +1,132 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_eltwise_fusion_pass.h"
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include "kernel/kernel_fusion.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
#include "common/graph_kernel/graph_kernel_flags.h"
|
||||
#include "backend/common/optimizer/fusion_id_allocator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr auto kAttrNoFusion = "no_fusion";
|
||||
|
||||
CNodePtr FindInputNode(const CNodePtr &cnode, const string &node_type, const kernel::FusionType &fusion_type) {
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t i = 1; i <= input_num; ++i) {
|
||||
auto input = cnode->input(i);
|
||||
if (input->isa<CNode>() && common::AnfAlgo::GetCNodeName(input) == node_type &&
|
||||
AnfAlgo::GetFusionType(input) == fusion_type) {
|
||||
return input->cast<CNodePtr>();
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
bool BatchMatmulEltwiseFusionPass::MatchPattern1(const CNodePtr &eltwise1, mindspore::HashSet<AnfNodePtr> *record) {
|
||||
// bmm - eltwise - eltwise1
|
||||
const std::set<string> kElem1TypeList = {kAddOpName, kReluOpName, kFusedMulAddOpName};
|
||||
if (kElem1TypeList.find(common::AnfAlgo::GetCNodeName(eltwise1)) == kElem1TypeList.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(eltwise1);
|
||||
for (size_t i = 1; i <= input_num; ++i) {
|
||||
auto eltwise1_input = eltwise1->input(i);
|
||||
if (eltwise1_input->isa<CNode>() && MatchPattern2(eltwise1_input->cast<CNodePtr>(), record)) {
|
||||
record->insert(eltwise1);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool BatchMatmulEltwiseFusionPass::MatchPattern2(const CNodePtr &eltwise, mindspore::HashSet<AnfNodePtr> *record) {
|
||||
// bmm - eltwise
|
||||
const std::set<string> kElemTypeList = {kFusedMulAddOpName, kAddOpName, kDivOpName,
|
||||
kRealDivOpName, kReluOpName, kReluGradOpName};
|
||||
if (kElemTypeList.find(common::AnfAlgo::GetCNodeName(eltwise)) == kElemTypeList.end()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CNodePtr bmm = FindInputNode(eltwise, kBatchMatMulOpName, kernel::FusionType::BATCH_MATMUL);
|
||||
if (bmm == nullptr || common::AnfAlgo::IsDynamicShape(bmm) || common::AnfAlgo::GetBooleanAttr(bmm, kAttrNoFusion)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
record->insert(eltwise);
|
||||
record->insert(bmm);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool BatchMatmulEltwiseFusionPass::MatchPattern3(const CNodePtr &eltwise, mindspore::HashSet<AnfNodePtr> *record) {
|
||||
// bmm - eltwise1(mul) - eltwise2(sigmoid) - eltwise(mul)
|
||||
if (common::AnfAlgo::GetCNodeName(eltwise) != kMulOpName) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CNodePtr eltwise2 = FindInputNode(eltwise, kSigmoidOpName, kernel::FusionType::ELEMWISE);
|
||||
if (eltwise2 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CNodePtr eltwise1 = FindInputNode(eltwise2, kMulOpName, kernel::FusionType::ELEMWISE);
|
||||
if (eltwise1 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CNodePtr bmm = FindInputNode(eltwise1, kBatchMatMulOpName, kernel::FusionType::BATCH_MATMUL);
|
||||
if (bmm == nullptr || common::AnfAlgo::IsDynamicShape(bmm)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
record->insert(eltwise);
|
||||
record->insert(eltwise2);
|
||||
record->insert(eltwise1);
|
||||
record->insert(bmm);
|
||||
return true;
|
||||
}
|
||||
|
||||
void BatchMatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
                                                            FusedNodeRecord *candidate_fusion) {
  // Scans the graph for element-wise TBE kernels that terminate one of the
  // three supported BatchMatMul + eltwise patterns and records each matched
  // node set as a fusion candidate.
  MS_EXCEPTION_IF_NULL(candidate_fusion);

  const auto &node_list = TopoSort(kernel_graph.get_return());
  for (auto &node : node_list) {
    // Null-check each node, consistent with the other buffer-fusion passes.
    MS_EXCEPTION_IF_NULL(node);
    if (!AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
        common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
      continue;
    }

    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    // Only element-wise TBE kernels can be the tail of a supported pattern.
    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
      mindspore::HashSet<AnfNodePtr> record;
      // Patterns are tried in order; the first match wins for this node.
      if (MatchPattern1(cnode, &record) || MatchPattern2(cnode, &record) || MatchPattern3(cnode, &record)) {
        candidate_fusion->push_back(record);
        SetRecordFusionId(record);
      }
    }
  }
}
}  // namespace opt
}  // namespace mindspore
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -13,8 +13,8 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_FUSION_PASS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_FUSION_PASS_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_ELTWISE_FUSION_PASS_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_ELTWISE_FUSION_PASS_H_
|
||||
|
||||
#include "utils/hash_set.h"
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/fusion_base_pass.h"
|
||||
|
@ -27,20 +27,21 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class BatchMatmulFusedMulAddFusionPass : public FusionBasePass {
|
||||
class BatchMatmulEltwiseFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit BatchMatmulFusedMulAddFusionPass(const FusionIdAllocatorPtr &idAllocator)
|
||||
: FusionBasePass("BatchMatmulFusedMulAddFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatmulFusedMulAddFusionPass);
|
||||
explicit BatchMatmulEltwiseFusionPass(const FusionIdAllocatorPtr &idAllocator)
|
||||
: FusionBasePass("BatchMatmulEltwiseFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::TbeBatchMatmulElementWiseFusionPass);
|
||||
}
|
||||
~BatchMatmulFusedMulAddFusionPass() override = default;
|
||||
~BatchMatmulEltwiseFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
private:
|
||||
void MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion);
|
||||
bool MatchPattern1(const CNodePtr &eltwise1, mindspore::HashSet<AnfNodePtr> *record);
|
||||
bool MatchPattern2(const CNodePtr &eltwise, mindspore::HashSet<AnfNodePtr> *record);
|
||||
bool MatchPattern3(const CNodePtr &eltwise, mindspore::HashSet<AnfNodePtr> *record);
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_FUSEDMULADD_FUSION_PASS_H_
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_BUFFER_FUSION_PASS_BATCHMATMUL_ELTWISE_FUSION_PASS_H_
|
|
@ -1,58 +0,0 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "plugin/device/ascend/optimizer/buffer_fusion/batchmatmul_fusedmuladd_fusion_pass.h"
|
||||
#include "kernel/kernel_fusion.h"
|
||||
#include "backend/common/session/anf_runtime_algorithm.h"
|
||||
#include "include/common/utils/anfalgo.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/common/optimizer/fusion_id_allocator.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePtr &cnode, const session::KernelGraph &,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
auto batch_matmul = cnode->input(kIndex2);
|
||||
MS_EXCEPTION_IF_NULL(batch_matmul);
|
||||
if (batch_matmul->isa<CNode>() && common::AnfAlgo::CheckPrimitiveType(batch_matmul, prim::kPrimBatchMatMul)) {
|
||||
mindspore::HashSet<AnfNodePtr> record{cnode, batch_matmul};
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
}
|
||||
}
|
||||
|
||||
void BatchMatmulFusedMulAddFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
                                                                FusedNodeRecord *candidate_fusion) {
  // Scans the graph for FusedMulAdd kernels and delegates candidate matching
  // to MatchBatchMatmulFusedMulAdd.
  MS_EXCEPTION_IF_NULL(candidate_fusion);
  const auto &node_list = TopoSort(kernel_graph.get_return());
  for (auto &node : node_list) {
    MS_EXCEPTION_IF_NULL(node);
    // Skip non-kernel nodes, nodes already assigned to a fusion group, and Return.
    const bool skip = !AnfUtils::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
                      common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn);
    if (skip) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);

    if (common::AnfAlgo::GetCNodeName(cnode) == kFusedMulAddOpName) {
      MatchBatchMatmulFusedMulAdd(cnode, kernel_graph, candidate_fusion);
    }
  }
}
}  // namespace opt
}  // namespace mindspore
|
|
@ -30,7 +30,9 @@ namespace opt {
|
|||
class MatmulDropoutDoMaskV3AddFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit MatmulDropoutDoMaskV3AddFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) {}
|
||||
: FusionBasePass("MatmulDropoutDoMaskV3AddFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MatMulDropOutDoMaskV3DFusionPass);
|
||||
}
|
||||
~MatmulDropoutDoMaskV3AddFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
Loading…
Reference in New Issue