add pass switch
Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
parent
59f03e8e21
commit
fb17c448c2
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
#include "mindspore/core/utils/ms_utils.h"
|
||||
|
||||
namespace mindspore::opt {
|
||||
namespace {
|
||||
constexpr char kMsAscendFusionSwitch[] = "MS_ASCEND_FUSION_SWITCH";
|
||||
} // namespace
|
||||
|
||||
bool PassWithSwitch::Run(const FuncGraphPtr &func_graph) {
|
||||
if (!PassSwitchManager::GetInstance().GetPassSwitch(name())) {
|
||||
return false; // false means no changed
|
||||
}
|
||||
|
||||
return RunPass(func_graph);
|
||||
}
|
||||
|
||||
AnfNodePtr PatternProcessPassWithSwitch::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
if (!PassSwitchManager::GetInstance().GetPassSwitch(name())) {
|
||||
return nullptr; // nullptr means no changed
|
||||
}
|
||||
|
||||
return PatternProcessPass::Run(func_graph, node);
|
||||
}
|
||||
|
||||
PassSwitchManager::PassSwitchManager() { SetSwitchFromEnv(); }
|
||||
|
||||
PassSwitchManager &PassSwitchManager::GetInstance() {
|
||||
static PassSwitchManager instance{};
|
||||
return instance;
|
||||
}
|
||||
|
||||
bool PassSwitchManager::GetPassSwitch(const std::string &pass_name) const {
|
||||
if (auto iter = env_pass_switch_.find(pass_name); iter != env_pass_switch_.end() && !env_switch_) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!LicManager::GetInstance().GetPassSwitch(GetPassEnum(pass_name))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
enum OptPassEnum PassSwitchManager::GetPassEnum(const std::string &pass_name) const {
|
||||
if (auto iter = pass_enum_map_.find(pass_name); iter != pass_enum_map_.end()) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
return OptPassEnum::Invalid;
|
||||
}
|
||||
|
||||
void PassSwitchManager::SetSwitchFromEnv() {
|
||||
auto sw_env = common::GetEnv(kMsAscendFusionSwitch);
|
||||
env_switch_ = (sw_env != "OFF" && sw_env != "off" && sw_env != "0");
|
||||
}
|
||||
} // namespace mindspore::opt
|
|
@ -0,0 +1,70 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_PASS_CONTROL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_PASS_CONTROL_H_
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class PassSwitchManager {
|
||||
public:
|
||||
static PassSwitchManager &GetInstance();
|
||||
void RegistPass(const std::string &pass_name) { env_pass_switch_.emplace(pass_name, true); }
|
||||
void RegistLicPass(const std::string &pass_name, enum OptPassEnum pass) { pass_enum_map_.emplace(pass_name, pass); }
|
||||
bool GetPassSwitch(const std::string &pass_name) const;
|
||||
|
||||
private:
|
||||
PassSwitchManager();
|
||||
~PassSwitchManager() = default;
|
||||
|
||||
enum OptPassEnum GetPassEnum(const std::string &pass_name) const;
|
||||
void SetSwitchFromEnv();
|
||||
|
||||
std::map<std::string, enum OptPassEnum> pass_enum_map_ = {};
|
||||
std::map<std::string, bool> env_pass_switch_ = {};
|
||||
bool env_switch_ = true;
|
||||
};
|
||||
|
||||
class PassWithSwitch : public Pass {
|
||||
public:
|
||||
explicit PassWithSwitch(const std::string &name = "pass") : Pass(name) {
|
||||
PassSwitchManager::GetInstance().RegistPass(name);
|
||||
}
|
||||
|
||||
virtual ~PassWithSwitch() = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
|
||||
protected:
|
||||
virtual bool RunPass(const FuncGraphPtr &func_graph) = 0;
|
||||
};
|
||||
|
||||
class PatternProcessPassWithSwitch : public PatternProcessPass {
|
||||
public:
|
||||
explicit PatternProcessPassWithSwitch(const std::string &name = "", bool multigraph = true)
|
||||
: PatternProcessPass(name, multigraph) {
|
||||
PassSwitchManager::GetInstance().RegistPass(name);
|
||||
}
|
||||
|
||||
~PatternProcessPassWithSwitch() override = default;
|
||||
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
|
||||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_PASS_CONTROL_H_
|
|
@ -39,9 +39,6 @@ void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePt
|
|||
void BatchMatmulFusedMulAddFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::BatchMatmulFusedMulAddFusionPass)) {
|
||||
return;
|
||||
}
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
for (auto &node : node_list) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
|
|
@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
|||
class BatchMatmulFusedMulAddFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit BatchMatmulFusedMulAddFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("BatchMatmulFusedMulAddFusionPass", idAllocator) {}
|
||||
: FusionBasePass("BatchMatmulFusedMulAddFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatmulFusedMulAddFusionPass);
|
||||
}
|
||||
~BatchMatmulFusedMulAddFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -54,9 +54,6 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
|
|||
|
||||
void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::BnupdateEltwiseEltwiseFusionPass)) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
for (auto &node : node_list) {
|
||||
|
|
|
@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
|||
class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {}
|
||||
: FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BnupdateEltwiseEltwiseFusionPass);
|
||||
}
|
||||
~BnupdateEltwiseEltwiseFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -43,9 +43,6 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
|
|||
|
||||
void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::BnupdateEltwiseFusionPass)) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
for (auto &node : node_list) {
|
||||
|
|
|
@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
|||
class BnupdateEltwiseFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit BnupdateEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {}
|
||||
: FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BnupdateEltwiseFusionPass);
|
||||
}
|
||||
~BnupdateEltwiseFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -65,9 +65,6 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
|
|||
|
||||
void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::Conv2DBackpropEltwiseFusionPass)) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
for (auto &node : node_list) {
|
||||
|
|
|
@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
|||
class Conv2DBackpropEltwiseEltwiseFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit Conv2DBackpropEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {}
|
||||
: FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::Conv2DBackpropEltwiseFusionPass);
|
||||
}
|
||||
~Conv2DBackpropEltwiseEltwiseFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -43,9 +43,6 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
|
|||
|
||||
void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::Conv2DBackpropEltwiseFusionPass)) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
for (auto &node : node_list) {
|
||||
|
|
|
@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
|||
class Conv2DBackpropEltwiseFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit Conv2DBackpropEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {}
|
||||
: FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::Conv2DBackpropEltwiseFusionPass);
|
||||
}
|
||||
~Conv2DBackpropEltwiseFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -40,9 +40,6 @@ void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const sess
|
|||
|
||||
void ConvBnReduceFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::ConvBnReduceFusionPass)) {
|
||||
return;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
for (auto &node : node_list) {
|
||||
|
|
|
@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
|||
class ConvBnReduceFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit ConvBnReduceFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("ConvBnReduceFusionPass", idAllocator) {}
|
||||
: FusionBasePass("ConvBnReduceFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::ConvBnReduceFusionPass);
|
||||
}
|
||||
~ConvBnReduceFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -51,9 +51,6 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
|
|||
void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::EltwiseFusionPass)) {
|
||||
return;
|
||||
}
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
std::reverse(node_list.begin(), node_list.end());
|
||||
for (auto &node : node_list) {
|
||||
|
|
|
@ -33,7 +33,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
|||
|
||||
class EltwiseFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {}
|
||||
explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::EltwiseFusionPass);
|
||||
}
|
||||
~EltwiseFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
|
||||
#include <memory>
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/optimizer/common/fusion_id_allocator.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
|
@ -105,7 +106,7 @@ bool FusionBasePass::MatchUBFusionPattern(const session::KernelGraph &kernel_gra
|
|||
return true;
|
||||
}
|
||||
|
||||
bool FusionBasePass::Run(const FuncGraphPtr &graph) {
|
||||
bool FusionBasePass::RunPass(const FuncGraphPtr &graph) {
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
|
|
|
@ -19,14 +19,14 @@
|
|||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include <utility>
|
||||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/pass.h"
|
||||
#include "backend/optimizer/common/fusion_id_allocator.h"
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "backend/kernel_compiler/kernel.h"
|
||||
#include "backend/session/kernel_graph.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -55,15 +55,15 @@ struct BufferFusionInfo_t {
|
|||
kernel::KernelBuildInfoPtr kernel_build_info;
|
||||
};
|
||||
|
||||
class FusionBasePass : public Pass {
|
||||
class FusionBasePass : public PassWithSwitch {
|
||||
public:
|
||||
FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator)
|
||||
: Pass(name), fusion_id_allocator(idAllocator) {}
|
||||
: PassWithSwitch(name), fusion_id_allocator(std::move(idAllocator)) {}
|
||||
~FusionBasePass() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph);
|
||||
|
||||
protected:
|
||||
bool RunPass(const FuncGraphPtr &graph) override;
|
||||
virtual void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) = 0;
|
||||
void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record);
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include "base/core_ops.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/optimizer/common/fusion_id_allocator.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -41,11 +40,6 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
|
|||
void MatmulConfusionTranposeFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MatmulConfusiontransposeUbFusion)) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
|
||||
|
|
|
@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
|||
class MatmulConfusionTranposeFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit MatmulConfusionTranposeFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("MatmulConfusionTranposeFusionPass", idAllocator) {}
|
||||
: FusionBasePass("MatmulConfusionTranposeFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MatmulConfusiontransposeUbFusion);
|
||||
}
|
||||
~MatmulConfusionTranposeFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -39,10 +39,6 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
|
|||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MatmulEltwiseFusionPass)) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
for (auto &node : node_list) {
|
||||
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
|
||||
|
|
|
@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
|||
class MatmulEltwiseFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit MatmulEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {}
|
||||
: FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MatmulEltwiseFusionPass);
|
||||
}
|
||||
~MatmulEltwiseFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include "base/core_ops.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/optimizer/common/fusion_id_allocator.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -57,9 +56,6 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
|
|||
void MultiOutputFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
|
||||
FusedNodeRecord *candidate_fusion) {
|
||||
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MultiOutputFusionPass)) {
|
||||
return;
|
||||
}
|
||||
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
||||
std::reverse(node_list.begin(), node_list.end());
|
||||
for (auto &node : node_list) {
|
||||
|
|
|
@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
|||
class MultiOutputFusionPass : public FusionBasePass {
|
||||
public:
|
||||
explicit MultiOutputFusionPass(FusionIdAllocatorPtr idAllocator)
|
||||
: FusionBasePass("MultiOutputFusionPass", idAllocator) {}
|
||||
: FusionBasePass("MultiOutputFusionPass", idAllocator) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MultiOutputFusionPass);
|
||||
}
|
||||
~MultiOutputFusionPass() override = default;
|
||||
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;
|
||||
|
||||
|
|
|
@ -514,7 +514,7 @@ bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int64_t, BufferFusionIn
|
|||
return true;
|
||||
}
|
||||
|
||||
bool UbPatternFusion::Run(const FuncGraphPtr &graph) {
|
||||
bool UbPatternFusion::RunPass(const FuncGraphPtr &graph) {
|
||||
bool changed = false;
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
|
||||
|
|
|
@ -32,13 +32,14 @@ namespace mindspore {
|
|||
namespace opt {
|
||||
using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
|
||||
|
||||
class UbPatternFusion : public Pass {
|
||||
class UbPatternFusion : public PassWithSwitch {
|
||||
public:
|
||||
UbPatternFusion() : Pass("TbeBufferFusion") {}
|
||||
UbPatternFusion() : PassWithSwitch("TbeBufferFusion") {}
|
||||
~UbPatternFusion() override = default;
|
||||
bool Run(const FuncGraphPtr &graph) override;
|
||||
|
||||
private:
|
||||
bool RunPass(const FuncGraphPtr &graph) override;
|
||||
|
||||
void GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
||||
std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) const;
|
||||
bool ReplaceFusionOp(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos, int64_t fusion_id,
|
||||
|
|
|
@ -23,7 +23,6 @@
|
|||
#include "base/core_ops.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -74,10 +73,6 @@ const AnfNodePtr BNReduceGradConv2dBackpropFilterFusion::Process(const FuncGraph
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::Resnet50DbnDwFusionPass)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto conv_back_filter = CheckAnfNodeIfCNodeAndInputSize(node, kConv2DBackpropFilterInputNum);
|
||||
auto bnreduce_grad = CheckAnfNodeIfCNodeAndInputSize(conv_back_filter->input(kIndex1), kBNTrainingReduceGradInputNum);
|
||||
if (!CheckSupported(conv_back_filter)) {
|
||||
|
|
|
@ -17,14 +17,16 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BN_REDUCE_GRAD_CONV2D_BACKPROP_FILTER_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class BNReduceGradConv2dBackpropFilterFusion : public PatternProcessPass {
|
||||
class BNReduceGradConv2dBackpropFilterFusion : public PatternProcessPassWithSwitch {
|
||||
public:
|
||||
explicit BNReduceGradConv2dBackpropFilterFusion(bool multigraph = true)
|
||||
: PatternProcessPass("bn_reduce_grad_conv2d_backprop_filter_fusion", multigraph) {}
|
||||
: PatternProcessPassWithSwitch("bn_reduce_grad_conv2d_backprop_filter_fusion", multigraph) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::Resnet50DbnDwFusionPass);
|
||||
}
|
||||
~BNReduceGradConv2dBackpropFilterFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -46,10 +45,6 @@ const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &gra
|
|||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::ClipByNormNoDivSquareSumFusion)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
BaseRef &input_gnode = (*equiv)[input_];
|
||||
BaseRef &constant_select_gnode = (*equiv)[constant_select_];
|
||||
BaseRef &constant_greater_gnode = (*equiv)[constant_greater_];
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -27,14 +27,15 @@ constexpr auto kConstantSelectVarName = "constant_select";
|
|||
constexpr auto kConstantGreaterVarName = "constant_greater";
|
||||
constexpr auto kConstantMaximumVarName = "constant_maximum";
|
||||
|
||||
class ClipByNormNoDivSquareSumFusion : public PatternProcessPass {
|
||||
class ClipByNormNoDivSquareSumFusion : public PatternProcessPassWithSwitch {
|
||||
public:
|
||||
explicit ClipByNormNoDivSquareSumFusion(bool multigraph = true)
|
||||
: PatternProcessPass("clip_by_norm_no_div_square_sum_fusion", multigraph) {
|
||||
: PatternProcessPassWithSwitch("clip_by_norm_no_div_square_sum_fusion", multigraph) {
|
||||
input_ = std::make_shared<Var>(kInputVarName);
|
||||
constant_select_ = std::make_shared<Var>(kConstantSelectVarName);
|
||||
constant_greater_ = std::make_shared<Var>(kConstantGreaterVarName);
|
||||
constant_maximum_ = std::make_shared<Var>(kConstantMaximumVarName);
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::ClipByNormNoDivSquareSumFusion);
|
||||
}
|
||||
~ClipByNormNoDivSquareSumFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
|
|
@ -22,7 +22,6 @@
|
|||
#include "utils/utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/utils.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -33,10 +32,6 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A
|
|||
const EquivPtr &equiv) const {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MatmulBiasaddFusion)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto matmul = GetAnfNodeByVar(equiv, matmul_var_);
|
||||
if (matmul == nullptr || !matmul->isa<CNode>()) {
|
||||
MS_LOG(EXCEPTION) << "Get CNode MatMul failed!"
|
||||
|
|
|
@ -17,17 +17,19 @@
|
|||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_
|
||||
|
||||
#include <memory>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class MatmulBiasaddFusion : public PatternProcessPass {
|
||||
class MatmulBiasaddFusion : public PatternProcessPassWithSwitch {
|
||||
public:
|
||||
explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {
|
||||
explicit MatmulBiasaddFusion(bool multigraph = true)
|
||||
: PatternProcessPassWithSwitch("matmul_biasadd_fusion", multigraph) {
|
||||
x0_ = std::make_shared<Var>();
|
||||
x1_ = std::make_shared<Var>();
|
||||
x2_ = std::make_shared<Var>();
|
||||
matmul_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMatMul->name()));
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MatmulBiasaddFusion);
|
||||
}
|
||||
|
||||
~MatmulBiasaddFusion() override = default;
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include <string>
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -21,7 +21,6 @@
|
|||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "frontend/optimizer/opt.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -72,10 +71,6 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MulAddFusion)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CNodePtr mul = nullptr;
|
||||
size_t mul_index = 0;
|
||||
if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) {
|
||||
|
|
|
@ -16,13 +16,15 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MUL_ADD_FUSION_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MUL_ADD_FUSION_H
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class MulAddFusion : public PatternProcessPass {
|
||||
class MulAddFusion : public PatternProcessPassWithSwitch {
|
||||
public:
|
||||
explicit MulAddFusion(bool multigraph = true) : PatternProcessPass("mul_add_fusion", multigraph) {}
|
||||
explicit MulAddFusion(bool multigraph = true) : PatternProcessPassWithSwitch("mul_add_fusion", multigraph) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MulAddFusion);
|
||||
}
|
||||
~MulAddFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
|
|
@ -16,11 +16,9 @@
|
|||
#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "frontend/optimizer/opt.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -60,10 +58,6 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MulAddNPass)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto addn = node->cast<CNodePtr>();
|
||||
if (addn == nullptr) {
|
||||
return nullptr;
|
||||
|
|
|
@ -16,13 +16,15 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class MulAddNFusion : public PatternProcessPass {
|
||||
class MulAddNFusion : public PatternProcessPassWithSwitch {
|
||||
public:
|
||||
explicit MulAddNFusion(bool multigraph = true) : PatternProcessPass("mul_addn_fusion", multigraph) {}
|
||||
explicit MulAddNFusion(bool multigraph = true) : PatternProcessPassWithSwitch("mul_addn_fusion", multigraph) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MulAddNPass);
|
||||
}
|
||||
~MulAddNFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -68,10 +67,6 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::ReshapeTransposeFusion)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
|
||||
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
|
||||
auto new_node = func_graph->NewCNode(inputs);
|
||||
|
|
|
@ -24,14 +24,16 @@
|
|||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/pattern_engine.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class ReshapeTransposeFusion : public PatternProcessPass {
|
||||
class ReshapeTransposeFusion : public PatternProcessPassWithSwitch {
|
||||
public:
|
||||
explicit ReshapeTransposeFusion(bool multigraph = true) : PatternProcessPass("reshape_transpose_fusion", multigraph) {
|
||||
explicit ReshapeTransposeFusion(bool multigraph = true)
|
||||
: PatternProcessPassWithSwitch("reshape_transpose_fusion", multigraph) {
|
||||
input_varptr_ = std::make_shared<Var>();
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::ReshapeTransposeFusion);
|
||||
}
|
||||
~ReshapeTransposeFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
@ -42,5 +44,4 @@ class ReshapeTransposeFusion : public PatternProcessPass {
|
|||
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include "ir/primitive.h"
|
||||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
|
|
@ -26,7 +26,6 @@
|
|||
#include "backend/optimizer/common/helper.h"
|
||||
#include "runtime/device/kernel_info.h"
|
||||
#include "utils/trace_base.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -100,10 +99,6 @@ const AnfNodePtr SquareSumFusion::Process(const FuncGraphPtr &graph, const AnfNo
|
|||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::SquareSumFusion)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
CNodePtr sum = nullptr;
|
||||
AnfNodePtr square_anf = nullptr;
|
||||
CNodePtr square = nullptr;
|
||||
|
|
|
@ -16,13 +16,15 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_
|
||||
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class SquareSumFusion : public PatternProcessPass {
|
||||
class SquareSumFusion : public PatternProcessPassWithSwitch {
|
||||
public:
|
||||
explicit SquareSumFusion(bool multigraph = true) : PatternProcessPass("square_sum_fusion", multigraph) {}
|
||||
explicit SquareSumFusion(bool multigraph = true) : PatternProcessPassWithSwitch("square_sum_fusion", multigraph) {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::SquareSumFusion);
|
||||
}
|
||||
~SquareSumFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include "utils/utils.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "base/core_ops.h"
|
||||
#include "runtime/device/ascend/lic_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -50,10 +49,6 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
|
|||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(equiv);
|
||||
|
||||
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::TransposeReshapeFusion)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum);
|
||||
MS_EXCEPTION_IF_NULL(reshape_cnode);
|
||||
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputTensorNum);
|
||||
|
|
|
@ -24,14 +24,16 @@
|
|||
#include "ir/anf.h"
|
||||
#include "backend/optimizer/common/pattern_engine.h"
|
||||
#include "backend/optimizer/common/helper.h"
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class TransposeReshapeFusion : public PatternProcessPass {
|
||||
class TransposeReshapeFusion : public PatternProcessPassWithSwitch {
|
||||
public:
|
||||
explicit TransposeReshapeFusion(bool multigraph = true) : PatternProcessPass("transpose_reshape_fusion", multigraph) {
|
||||
explicit TransposeReshapeFusion(bool multigraph = true)
|
||||
: PatternProcessPassWithSwitch("transpose_reshape_fusion", multigraph) {
|
||||
input_varptr_ = std::make_shared<Var>();
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::TransposeReshapeFusion);
|
||||
}
|
||||
~TransposeReshapeFusion() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
|
|
|
@ -40,8 +40,7 @@ const BaseRef PatternProcessPass::DefinePattern() const {
|
|||
|
||||
void PatternProcessPass::Build() {
|
||||
VarPtr fg = std::make_shared<Var>("RootG");
|
||||
BaseRef pattern = std::move(DefinePattern());
|
||||
pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_);
|
||||
pattern_ = SexpToNode(DefinePattern(), fg, primitive_vars_.get(), multigraph_);
|
||||
}
|
||||
|
||||
AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
|
||||
|
|
|
@ -33,7 +33,7 @@ class Pass {
|
|||
explicit Pass(const std::string &name = "pass") : name_(name) {}
|
||||
virtual ~Pass() = default;
|
||||
virtual bool Run(const FuncGraphPtr &func_graph) = 0;
|
||||
virtual std::string name() const { return name_; }
|
||||
const std::string &name() const { return name_; }
|
||||
void SetCacheManager(const CacheManagerPtr &cm) { cache_manager_ = cm; }
|
||||
const CacheManagerPtr &GetCacheManager() const { return cache_manager_; }
|
||||
|
||||
|
|
|
@ -45,6 +45,7 @@ enum class OptPassEnum {
|
|||
Resnet50DbnDwFusionPass,
|
||||
MatmulConfusiontransposeUbFusion,
|
||||
TbeBatchMatmulElementWiseFusionPass,
|
||||
Invalid,
|
||||
};
|
||||
|
||||
class LicManager {
|
||||
|
|
|
@ -0,0 +1,301 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "common/backend_common_test.h"
|
||||
#define private public
|
||||
#include "backend/optimizer/ascend/ascend_pass_control.h"
|
||||
#undef private
|
||||
|
||||
namespace {
|
||||
constexpr char kMsAscendFusionSwitch[] = "MS_ASCEND_FUSION_SWITCH";
|
||||
} // namespace
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
class PlantPass : public Pass {
|
||||
public:
|
||||
PlantPass() : Pass("plant") {}
|
||||
~PlantPass() override = default;
|
||||
bool GetRunStatus() const { return is_run_; }
|
||||
bool Run(const FuncGraphPtr &) override {
|
||||
is_run_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_run_ = false;
|
||||
};
|
||||
|
||||
class PlantPatternPass : public PatternProcessPass {
|
||||
public:
|
||||
PlantPatternPass() : PatternProcessPass("plant_pattern") { is_run_ = false; }
|
||||
~PlantPatternPass() override = default;
|
||||
bool GetRunStatus() const { return is_run_; }
|
||||
|
||||
const BaseRef DefinePattern() const override { return BaseRef({std::make_shared<Var>()}); }
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override {
|
||||
is_run_ = true;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline static bool is_run_ = false;
|
||||
};
|
||||
|
||||
class TestPass : public PassWithSwitch {
|
||||
public:
|
||||
TestPass() : PassWithSwitch("test") {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::DereluFusion);
|
||||
}
|
||||
~TestPass() override = default;
|
||||
bool GetRunStatus() const { return is_run_; }
|
||||
bool RunPass(const FuncGraphPtr &) override {
|
||||
is_run_ = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_run_ = false;
|
||||
};
|
||||
|
||||
class TestPatternPass : public PatternProcessPassWithSwitch {
|
||||
public:
|
||||
TestPatternPass() : PatternProcessPassWithSwitch("test_pattern") {
|
||||
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::DereluFusion);
|
||||
is_run_ = false;
|
||||
}
|
||||
~TestPatternPass() override = default;
|
||||
bool GetRunStatus() const { return is_run_; }
|
||||
|
||||
const BaseRef DefinePattern() const override { return BaseRef({std::make_shared<Var>()}); }
|
||||
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override {
|
||||
is_run_ = true;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
inline static bool is_run_ = false;
|
||||
};
|
||||
|
||||
class TestAscendPassControl : public UT::Common {
|
||||
public:
|
||||
TestAscendPassControl() = default;
|
||||
|
||||
void TearDown() override {
|
||||
(void)unsetenv(kMsAscendFusionSwitch);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
}
|
||||
};
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: no MS_ASCEND_FUSION_SWITCH set and run pass
|
||||
/// Expectation: switch pass run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_no_env_for_pass) {
|
||||
(void)unsetenv(kMsAscendFusionSwitch);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
pass.Run(nullptr);
|
||||
ASSERT_TRUE(pass.GetRunStatus());
|
||||
PlantPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(nullptr);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: set MS_ASCEND_FUSION_SWITCH as "on" and run pass
|
||||
/// Expectation: switch pass run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_env_on_for_pass_0) {
|
||||
(void)setenv(kMsAscendFusionSwitch, "on", 1);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
pass.Run(nullptr);
|
||||
ASSERT_TRUE(pass.GetRunStatus());
|
||||
PlantPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(nullptr);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: set invalid MS_ASCEND_FUSION_SWITCH and run pass
|
||||
/// Expectation: switch pass run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_env_on_for_pass_1) {
|
||||
(void)setenv(kMsAscendFusionSwitch, "invalidxxxxxxxx", 1);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
pass.Run(nullptr);
|
||||
ASSERT_TRUE(pass.GetRunStatus());
|
||||
PlantPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(nullptr);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: set MS_ASCEND_FUSION_SWITCH as "0" and run pass
|
||||
/// Expectation: switch pass dont run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_env_off_for_pass_0) {
|
||||
(void)setenv(kMsAscendFusionSwitch, "0", 1);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
pass.Run(nullptr);
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
PlantPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(nullptr);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: set MS_ASCEND_FUSION_SWITCH as "off" and run pass
|
||||
/// Expectation: switch pass dont run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_env_off_for_pass_1) {
|
||||
(void)setenv(kMsAscendFusionSwitch, "off", 1);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
pass.Run(nullptr);
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
PlantPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(nullptr);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: set MS_ASCEND_FUSION_SWITCH as "OFF" and run pass
|
||||
/// Expectation: switch pass dont run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_env_off_for_pass_2) {
|
||||
(void)setenv(kMsAscendFusionSwitch, "OFF", 1);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
pass.Run(nullptr);
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
PlantPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(nullptr);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: no MS_ASCEND_FUSION_SWITCH set and run pattern pass
|
||||
/// Expectation: switch pass run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_no_env_for_pattern_pass) {
|
||||
(void)unsetenv(kMsAscendFusionSwitch);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPatternPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
FuncGraphPtr graph = std::make_shared<FuncGraph>();
|
||||
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
|
||||
pass.Run(graph, node);
|
||||
ASSERT_TRUE(pass.GetRunStatus());
|
||||
PlantPatternPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(graph, node);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: MS_ASCEND_FUSION_SWITCH set on and run pattern pass
|
||||
/// Expectation: switch pass run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_env_on_for_pattern_pass) {
|
||||
(void)setenv(kMsAscendFusionSwitch, "on", 1);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPatternPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
FuncGraphPtr graph = std::make_shared<FuncGraph>();
|
||||
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
|
||||
pass.Run(graph, node);
|
||||
ASSERT_TRUE(pass.GetRunStatus());
|
||||
PlantPatternPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(graph, node);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: MS_ASCEND_FUSION_SWITCH set invalid and run pattern pass
|
||||
/// Expectation: switch pass run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_env_invalid_for_pattern_pass) {
|
||||
(void)setenv(kMsAscendFusionSwitch, "invalid_xxasdasdasfsldjmg", 1);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPatternPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
FuncGraphPtr graph = std::make_shared<FuncGraph>();
|
||||
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
|
||||
pass.Run(graph, node);
|
||||
ASSERT_TRUE(pass.GetRunStatus());
|
||||
PlantPatternPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(graph, node);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: MS_ASCEND_FUSION_SWITCH set off and run pattern pass
|
||||
/// Expectation: switch pass dont run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_env_off_for_pattern_pass_0) {
|
||||
(void)setenv(kMsAscendFusionSwitch, "off", 1);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPatternPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
FuncGraphPtr graph = std::make_shared<FuncGraph>();
|
||||
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
|
||||
pass.Run(graph, node);
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
PlantPatternPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(graph, node);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: MS_ASCEND_FUSION_SWITCH set OFF and run pattern pass
|
||||
/// Expectation: switch pass dont run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_env_off_for_pattern_pass_1) {
|
||||
(void)setenv(kMsAscendFusionSwitch, "OFF", 1);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPatternPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
FuncGraphPtr graph = std::make_shared<FuncGraph>();
|
||||
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
|
||||
pass.Run(graph, node);
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
PlantPatternPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(graph, node);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
|
||||
/// Feature: Pass Switch
|
||||
/// Description: MS_ASCEND_FUSION_SWITCH set 0 and run pattern pass
|
||||
/// Expectation: switch pass dont run, plant pass run
|
||||
TEST_F(TestAscendPassControl, test_env_off_for_pattern_pass_2) {
|
||||
(void)setenv(kMsAscendFusionSwitch, "0", 1);
|
||||
PassSwitchManager::GetInstance().SetSwitchFromEnv();
|
||||
TestPatternPass pass;
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
FuncGraphPtr graph = std::make_shared<FuncGraph>();
|
||||
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
|
||||
pass.Run(graph, node);
|
||||
ASSERT_FALSE(pass.GetRunStatus());
|
||||
PlantPatternPass plant_pass;
|
||||
ASSERT_FALSE(plant_pass.GetRunStatus());
|
||||
plant_pass.Run(graph, node);
|
||||
ASSERT_TRUE(plant_pass.GetRunStatus());
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue