add pass switch

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
This commit is contained in:
zhoufeng 2021-10-29 16:37:00 +08:00
parent 59f03e8e21
commit fb17c448c2
50 changed files with 526 additions and 128 deletions

View File

@ -0,0 +1,71 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/ascend_pass_control.h"
#include "mindspore/core/utils/ms_utils.h"
namespace mindspore::opt {
namespace {
constexpr char kMsAscendFusionSwitch[] = "MS_ASCEND_FUSION_SWITCH";
} // namespace
bool PassWithSwitch::Run(const FuncGraphPtr &func_graph) {
if (!PassSwitchManager::GetInstance().GetPassSwitch(name())) {
return false; // false means no changed
}
return RunPass(func_graph);
}
AnfNodePtr PatternProcessPassWithSwitch::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
if (!PassSwitchManager::GetInstance().GetPassSwitch(name())) {
return nullptr; // nullptr means no changed
}
return PatternProcessPass::Run(func_graph, node);
}
PassSwitchManager::PassSwitchManager() { SetSwitchFromEnv(); }
PassSwitchManager &PassSwitchManager::GetInstance() {
static PassSwitchManager instance{};
return instance;
}
bool PassSwitchManager::GetPassSwitch(const std::string &pass_name) const {
if (auto iter = env_pass_switch_.find(pass_name); iter != env_pass_switch_.end() && !env_switch_) {
return false;
}
if (!LicManager::GetInstance().GetPassSwitch(GetPassEnum(pass_name))) {
return false;
}
return true;
}
enum OptPassEnum PassSwitchManager::GetPassEnum(const std::string &pass_name) const {
if (auto iter = pass_enum_map_.find(pass_name); iter != pass_enum_map_.end()) {
return iter->second;
}
return OptPassEnum::Invalid;
}
void PassSwitchManager::SetSwitchFromEnv() {
auto sw_env = common::GetEnv(kMsAscendFusionSwitch);
env_switch_ = (sw_env != "OFF" && sw_env != "off" && sw_env != "0");
}
} // namespace mindspore::opt

View File

@ -0,0 +1,70 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_PASS_CONTROL_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_PASS_CONTROL_H_
#include <string>
#include <map>
#include "backend/optimizer/common/optimizer.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
class PassSwitchManager {
public:
static PassSwitchManager &GetInstance();
void RegistPass(const std::string &pass_name) { env_pass_switch_.emplace(pass_name, true); }
void RegistLicPass(const std::string &pass_name, enum OptPassEnum pass) { pass_enum_map_.emplace(pass_name, pass); }
bool GetPassSwitch(const std::string &pass_name) const;
private:
PassSwitchManager();
~PassSwitchManager() = default;
enum OptPassEnum GetPassEnum(const std::string &pass_name) const;
void SetSwitchFromEnv();
std::map<std::string, enum OptPassEnum> pass_enum_map_ = {};
std::map<std::string, bool> env_pass_switch_ = {};
bool env_switch_ = true;
};
class PassWithSwitch : public Pass {
public:
explicit PassWithSwitch(const std::string &name = "pass") : Pass(name) {
PassSwitchManager::GetInstance().RegistPass(name);
}
virtual ~PassWithSwitch() = default;
bool Run(const FuncGraphPtr &func_graph) override;
protected:
virtual bool RunPass(const FuncGraphPtr &func_graph) = 0;
};
class PatternProcessPassWithSwitch : public PatternProcessPass {
public:
explicit PatternProcessPassWithSwitch(const std::string &name = "", bool multigraph = true)
: PatternProcessPass(name, multigraph) {
PassSwitchManager::GetInstance().RegistPass(name);
}
~PatternProcessPassWithSwitch() override = default;
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_ASCEND_PASS_CONTROL_H_

View File

@ -39,9 +39,6 @@ void BatchMatmulFusedMulAddFusionPass::MatchBatchMatmulFusedMulAdd(const CNodePt
void BatchMatmulFusedMulAddFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::BatchMatmulFusedMulAddFusionPass)) {
return;
}
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
MS_EXCEPTION_IF_NULL(node);

View File

@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class BatchMatmulFusedMulAddFusionPass : public FusionBasePass {
public:
explicit BatchMatmulFusedMulAddFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("BatchMatmulFusedMulAddFusionPass", idAllocator) {}
: FusionBasePass("BatchMatmulFusedMulAddFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BatchMatmulFusedMulAddFusionPass);
}
~BatchMatmulFusedMulAddFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -54,9 +54,6 @@ void BnupdateEltwiseEltwiseFusionPass::MatchBnupdateAddRelu(const CNodePtr &cnod
void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::BnupdateEltwiseEltwiseFusionPass)) {
return;
}
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {

View File

@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass {
public:
explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {}
: FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BnupdateEltwiseEltwiseFusionPass);
}
~BnupdateEltwiseEltwiseFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -43,9 +43,6 @@ void BnupdateEltwiseFusionPass::MatchBnupdateDoubleOutputEltwise(const CNodePtr
void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::BnupdateEltwiseFusionPass)) {
return;
}
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {

View File

@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class BnupdateEltwiseFusionPass : public FusionBasePass {
public:
explicit BnupdateEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {}
: FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::BnupdateEltwiseFusionPass);
}
~BnupdateEltwiseFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -65,9 +65,6 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltw
void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::Conv2DBackpropEltwiseFusionPass)) {
return;
}
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {

View File

@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class Conv2DBackpropEltwiseEltwiseFusionPass : public FusionBasePass {
public:
explicit Conv2DBackpropEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {}
: FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::Conv2DBackpropEltwiseFusionPass);
}
~Conv2DBackpropEltwiseEltwiseFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -43,9 +43,6 @@ void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNod
void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::Conv2DBackpropEltwiseFusionPass)) {
return;
}
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {

View File

@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class Conv2DBackpropEltwiseFusionPass : public FusionBasePass {
public:
explicit Conv2DBackpropEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {}
: FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::Conv2DBackpropEltwiseFusionPass);
}
~Conv2DBackpropEltwiseFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -40,9 +40,6 @@ void ConvBnReduceFusionPass::MatchConvBnreduce(const CNodePtr &cnode, const sess
void ConvBnReduceFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::ConvBnReduceFusionPass)) {
return;
}
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {

View File

@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class ConvBnReduceFusionPass : public FusionBasePass {
public:
explicit ConvBnReduceFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("ConvBnReduceFusionPass", idAllocator) {}
: FusionBasePass("ConvBnReduceFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::ConvBnReduceFusionPass);
}
~ConvBnReduceFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -51,9 +51,6 @@ void EltwiseFusionPass::MatchEltwise(const CNodePtr &cnode, const session::Kerne
void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::EltwiseFusionPass)) {
return;
}
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
std::reverse(node_list.begin(), node_list.end());
for (auto &node : node_list) {

View File

@ -33,7 +33,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class EltwiseFusionPass : public FusionBasePass {
public:
explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {}
explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::EltwiseFusionPass);
}
~EltwiseFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "backend/optimizer/ascend/buffer_fusion/fusion_base_pass.h"
#include <memory>
#include "utils/ms_context.h"
#include "backend/optimizer/common/fusion_id_allocator.h"
#include "backend/session/anf_runtime_algorithm.h"
@ -105,7 +106,7 @@ bool FusionBasePass::MatchUBFusionPattern(const session::KernelGraph &kernel_gra
return true;
}
bool FusionBasePass::Run(const FuncGraphPtr &graph) {
bool FusionBasePass::RunPass(const FuncGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(graph);
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
MS_EXCEPTION_IF_NULL(kernel_graph);

View File

@ -19,14 +19,14 @@
#include <unordered_set>
#include <vector>
#include <string>
#include <utility>
#include "ir/anf.h"
#include "backend/optimizer/common/pass.h"
#include "backend/optimizer/common/fusion_id_allocator.h"
#include "backend/optimizer/ascend/ascend_pass_control.h"
#include "runtime/device/kernel_info.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/session/kernel_graph.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -55,15 +55,15 @@ struct BufferFusionInfo_t {
kernel::KernelBuildInfoPtr kernel_build_info;
};
class FusionBasePass : public Pass {
class FusionBasePass : public PassWithSwitch {
public:
FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator)
: Pass(name), fusion_id_allocator(idAllocator) {}
: PassWithSwitch(name), fusion_id_allocator(std::move(idAllocator)) {}
~FusionBasePass() override = default;
bool Run(const FuncGraphPtr &graph) override;
bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph);
protected:
bool RunPass(const FuncGraphPtr &graph) override;
virtual void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) = 0;
void SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record);

View File

@ -19,7 +19,6 @@
#include "base/core_ops.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/fusion_id_allocator.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -41,11 +40,6 @@ void MatmulConfusionTranposeFusionPass::MatchMatmulConfusionTranpose(const CNode
void MatmulConfusionTranposeFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MatmulConfusiontransposeUbFusion)) {
return;
}
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||

View File

@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class MatmulConfusionTranposeFusionPass : public FusionBasePass {
public:
explicit MatmulConfusionTranposeFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("MatmulConfusionTranposeFusionPass", idAllocator) {}
: FusionBasePass("MatmulConfusionTranposeFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MatmulConfusiontransposeUbFusion);
}
~MatmulConfusionTranposeFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -39,10 +39,6 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MatmulEltwiseFusionPass)) {
return;
}
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
for (auto &node : node_list) {
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||

View File

@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class MatmulEltwiseFusionPass : public FusionBasePass {
public:
explicit MatmulEltwiseFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {}
: FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MatmulEltwiseFusionPass);
}
~MatmulEltwiseFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -19,7 +19,6 @@
#include "base/core_ops.h"
#include "utils/ms_context.h"
#include "backend/optimizer/common/fusion_id_allocator.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -57,9 +56,6 @@ void MultiOutputFusionPass::MatchMultiOutputEltwise(const CNodePtr &cnode, const
void MultiOutputFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MultiOutputFusionPass)) {
return;
}
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
std::reverse(node_list.begin(), node_list.end());
for (auto &node : node_list) {

View File

@ -34,7 +34,9 @@ using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class MultiOutputFusionPass : public FusionBasePass {
public:
explicit MultiOutputFusionPass(FusionIdAllocatorPtr idAllocator)
: FusionBasePass("MultiOutputFusionPass", idAllocator) {}
: FusionBasePass("MultiOutputFusionPass", idAllocator) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MultiOutputFusionPass);
}
~MultiOutputFusionPass() override = default;
void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override;

View File

@ -514,7 +514,7 @@ bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int64_t, BufferFusionIn
return true;
}
bool UbPatternFusion::Run(const FuncGraphPtr &graph) {
bool UbPatternFusion::RunPass(const FuncGraphPtr &graph) {
bool changed = false;
MS_EXCEPTION_IF_NULL(graph);
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();

View File

@ -32,13 +32,14 @@ namespace mindspore {
namespace opt {
using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
class UbPatternFusion : public Pass {
class UbPatternFusion : public PassWithSwitch {
public:
UbPatternFusion() : Pass("TbeBufferFusion") {}
UbPatternFusion() : PassWithSwitch("TbeBufferFusion") {}
~UbPatternFusion() override = default;
bool Run(const FuncGraphPtr &graph) override;
private:
bool RunPass(const FuncGraphPtr &graph) override;
void GetBufferFusionInfo(session::KernelGraph *kernel_graph,
std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) const;
bool ReplaceFusionOp(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos, int64_t fusion_id,

View File

@ -23,7 +23,6 @@
#include "base/core_ops.h"
#include "abstract/abstract_value.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -74,10 +73,6 @@ const AnfNodePtr BNReduceGradConv2dBackpropFilterFusion::Process(const FuncGraph
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::Resnet50DbnDwFusionPass)) {
return nullptr;
}
auto conv_back_filter = CheckAnfNodeIfCNodeAndInputSize(node, kConv2DBackpropFilterInputNum);
auto bnreduce_grad = CheckAnfNodeIfCNodeAndInputSize(conv_back_filter->input(kIndex1), kBNTrainingReduceGradInputNum);
if (!CheckSupported(conv_back_filter)) {

View File

@ -17,14 +17,16 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_BN_REDUCE_GRAD_CONV2D_BACKPROP_FILTER_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_pass_control.h"
namespace mindspore {
namespace opt {
class BNReduceGradConv2dBackpropFilterFusion : public PatternProcessPass {
class BNReduceGradConv2dBackpropFilterFusion : public PatternProcessPassWithSwitch {
public:
explicit BNReduceGradConv2dBackpropFilterFusion(bool multigraph = true)
: PatternProcessPass("bn_reduce_grad_conv2d_backprop_filter_fusion", multigraph) {}
: PatternProcessPassWithSwitch("bn_reduce_grad_conv2d_backprop_filter_fusion", multigraph) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::Resnet50DbnDwFusionPass);
}
~BNReduceGradConv2dBackpropFilterFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

View File

@ -21,7 +21,6 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "ir/primitive.h"
#include "utils/utils.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -46,10 +45,6 @@ const AnfNodePtr ClipByNormNoDivSquareSumFusion::Process(const FuncGraphPtr &gra
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(equiv);
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::ClipByNormNoDivSquareSumFusion)) {
return nullptr;
}
BaseRef &input_gnode = (*equiv)[input_];
BaseRef &constant_select_gnode = (*equiv)[constant_select_];
BaseRef &constant_greater_gnode = (*equiv)[constant_greater_];

View File

@ -18,7 +18,7 @@
#include <vector>
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_pass_control.h"
namespace mindspore {
namespace opt {
@ -27,14 +27,15 @@ constexpr auto kConstantSelectVarName = "constant_select";
constexpr auto kConstantGreaterVarName = "constant_greater";
constexpr auto kConstantMaximumVarName = "constant_maximum";
class ClipByNormNoDivSquareSumFusion : public PatternProcessPass {
class ClipByNormNoDivSquareSumFusion : public PatternProcessPassWithSwitch {
public:
explicit ClipByNormNoDivSquareSumFusion(bool multigraph = true)
: PatternProcessPass("clip_by_norm_no_div_square_sum_fusion", multigraph) {
: PatternProcessPassWithSwitch("clip_by_norm_no_div_square_sum_fusion", multigraph) {
input_ = std::make_shared<Var>(kInputVarName);
constant_select_ = std::make_shared<Var>(kConstantSelectVarName);
constant_greater_ = std::make_shared<Var>(kConstantGreaterVarName);
constant_maximum_ = std::make_shared<Var>(kConstantMaximumVarName);
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::ClipByNormNoDivSquareSumFusion);
}
~ClipByNormNoDivSquareSumFusion() override = default;
const BaseRef DefinePattern() const override;

View File

@ -22,7 +22,6 @@
#include "utils/utils.h"
#include "abstract/abstract_value.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {

View File

@ -20,7 +20,6 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "utils/trace_base.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {

View File

@ -19,7 +19,6 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "utils/trace_base.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -33,10 +32,6 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(graph);
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MatmulBiasaddFusion)) {
return nullptr;
}
auto matmul = GetAnfNodeByVar(equiv, matmul_var_);
if (matmul == nullptr || !matmul->isa<CNode>()) {
MS_LOG(EXCEPTION) << "Get CNode MatMul failed!"

View File

@ -17,17 +17,19 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MATMUL_BIASADD_FUSION_H_
#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_pass_control.h"
namespace mindspore {
namespace opt {
class MatmulBiasaddFusion : public PatternProcessPass {
class MatmulBiasaddFusion : public PatternProcessPassWithSwitch {
public:
explicit MatmulBiasaddFusion(bool multigraph = true) : PatternProcessPass("matmul_biasadd_fusion", multigraph) {
explicit MatmulBiasaddFusion(bool multigraph = true)
: PatternProcessPassWithSwitch("matmul_biasadd_fusion", multigraph) {
x0_ = std::make_shared<Var>();
x1_ = std::make_shared<Var>();
x2_ = std::make_shared<Var>();
matmul_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimMatMul->name()));
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MatmulBiasaddFusion);
}
~MatmulBiasaddFusion() override = default;

View File

@ -19,7 +19,6 @@
#include <string>
#include "backend/optimizer/common/helper.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {

View File

@ -21,7 +21,6 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/optimizer/opt.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -72,10 +71,6 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP
return nullptr;
}
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MulAddFusion)) {
return nullptr;
}
CNodePtr mul = nullptr;
size_t mul_index = 0;
if (!GetMul(graph, add, &mul, &mul_index) || mul == nullptr || mul_index == 0) {

View File

@ -16,13 +16,15 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MUL_ADD_FUSION_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_MUL_ADD_FUSION_H
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_pass_control.h"
namespace mindspore {
namespace opt {
class MulAddFusion : public PatternProcessPass {
class MulAddFusion : public PatternProcessPassWithSwitch {
public:
explicit MulAddFusion(bool multigraph = true) : PatternProcessPass("mul_add_fusion", multigraph) {}
explicit MulAddFusion(bool multigraph = true) : PatternProcessPassWithSwitch("mul_add_fusion", multigraph) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MulAddFusion);
}
~MulAddFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

View File

@ -16,11 +16,9 @@
#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
#include <vector>
#include <memory>
#include <utility>
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/optimizer/opt.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -60,10 +58,6 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
return nullptr;
}
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::MulAddNPass)) {
return nullptr;
}
auto addn = node->cast<CNodePtr>();
if (addn == nullptr) {
return nullptr;

View File

@ -16,13 +16,15 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_PASS_MUL_ADDN_FUSION_H
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_pass_control.h"
namespace mindspore {
namespace opt {
class MulAddNFusion : public PatternProcessPass {
class MulAddNFusion : public PatternProcessPassWithSwitch {
public:
explicit MulAddNFusion(bool multigraph = true) : PatternProcessPass("mul_addn_fusion", multigraph) {}
explicit MulAddNFusion(bool multigraph = true) : PatternProcessPassWithSwitch("mul_addn_fusion", multigraph) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::MulAddNPass);
}
~MulAddNFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

View File

@ -20,7 +20,6 @@
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "base/core_ops.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -68,10 +67,6 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
return nullptr;
}
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::ReshapeTransposeFusion)) {
return nullptr;
}
auto prim = std::make_shared<Primitive>(kConfusionTransposeDOpName);
std::vector<AnfNodePtr> inputs = {NewValueNode(prim), utils::cast<AnfNodePtr>((*equiv)[input_varptr_])};
auto new_node = func_graph->NewCNode(inputs);

View File

@ -24,14 +24,16 @@
#include "ir/anf.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_pass_control.h"
namespace mindspore {
namespace opt {
class ReshapeTransposeFusion : public PatternProcessPass {
class ReshapeTransposeFusion : public PatternProcessPassWithSwitch {
public:
explicit ReshapeTransposeFusion(bool multigraph = true) : PatternProcessPass("reshape_transpose_fusion", multigraph) {
explicit ReshapeTransposeFusion(bool multigraph = true)
: PatternProcessPassWithSwitch("reshape_transpose_fusion", multigraph) {
input_varptr_ = std::make_shared<Var>();
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::ReshapeTransposeFusion);
}
~ReshapeTransposeFusion() override = default;
const BaseRef DefinePattern() const override;
@ -42,5 +44,4 @@ class ReshapeTransposeFusion : public PatternProcessPass {
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_RESHAPE_TRANSPOSE_FUSION_H_

View File

@ -19,7 +19,6 @@
#include "ir/primitive.h"
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {

View File

@ -26,7 +26,6 @@
#include "backend/optimizer/common/helper.h"
#include "runtime/device/kernel_info.h"
#include "utils/trace_base.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -100,10 +99,6 @@ const AnfNodePtr SquareSumFusion::Process(const FuncGraphPtr &graph, const AnfNo
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(node);
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::SquareSumFusion)) {
return nullptr;
}
CNodePtr sum = nullptr;
AnfNodePtr square_anf = nullptr;
CNodePtr square = nullptr;

View File

@ -16,13 +16,15 @@
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_SQUARE_SUM_FUSION_H_
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_pass_control.h"
namespace mindspore {
namespace opt {
class SquareSumFusion : public PatternProcessPass {
class SquareSumFusion : public PatternProcessPassWithSwitch {
public:
explicit SquareSumFusion(bool multigraph = true) : PatternProcessPass("square_sum_fusion", multigraph) {}
explicit SquareSumFusion(bool multigraph = true) : PatternProcessPassWithSwitch("square_sum_fusion", multigraph) {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::SquareSumFusion);
}
~SquareSumFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

View File

@ -20,7 +20,6 @@
#include "utils/utils.h"
#include "backend/optimizer/common/helper.h"
#include "base/core_ops.h"
#include "runtime/device/ascend/lic_manager.h"
namespace mindspore {
namespace opt {
@ -50,10 +49,6 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
if (!LicManager::GetInstance().GetPassSwitch(OptPassEnum::TransposeReshapeFusion)) {
return nullptr;
}
auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum);
MS_EXCEPTION_IF_NULL(reshape_cnode);
auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputTensorNum);

View File

@ -24,14 +24,16 @@
#include "ir/anf.h"
#include "backend/optimizer/common/pattern_engine.h"
#include "backend/optimizer/common/helper.h"
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/ascend/ascend_pass_control.h"
namespace mindspore {
namespace opt {
class TransposeReshapeFusion : public PatternProcessPass {
class TransposeReshapeFusion : public PatternProcessPassWithSwitch {
public:
explicit TransposeReshapeFusion(bool multigraph = true) : PatternProcessPass("transpose_reshape_fusion", multigraph) {
explicit TransposeReshapeFusion(bool multigraph = true)
: PatternProcessPassWithSwitch("transpose_reshape_fusion", multigraph) {
input_varptr_ = std::make_shared<Var>();
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::TransposeReshapeFusion);
}
~TransposeReshapeFusion() override = default;
const BaseRef DefinePattern() const override;

View File

@ -40,8 +40,7 @@ const BaseRef PatternProcessPass::DefinePattern() const {
void PatternProcessPass::Build() {
VarPtr fg = std::make_shared<Var>("RootG");
BaseRef pattern = std::move(DefinePattern());
pattern_ = SexpToNode(pattern, fg, primitive_vars_.get(), multigraph_);
pattern_ = SexpToNode(DefinePattern(), fg, primitive_vars_.get(), multigraph_);
}
AnfNodePtr PatternProcessPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {

View File

@ -33,7 +33,7 @@ class Pass {
explicit Pass(const std::string &name = "pass") : name_(name) {}
virtual ~Pass() = default;
virtual bool Run(const FuncGraphPtr &func_graph) = 0;
virtual std::string name() const { return name_; }
const std::string &name() const { return name_; }
void SetCacheManager(const CacheManagerPtr &cm) { cache_manager_ = cm; }
const CacheManagerPtr &GetCacheManager() const { return cache_manager_; }

View File

@ -45,6 +45,7 @@ enum class OptPassEnum {
Resnet50DbnDwFusionPass,
MatmulConfusiontransposeUbFusion,
TbeBatchMatmulElementWiseFusionPass,
Invalid,
};
class LicManager {

View File

@ -0,0 +1,301 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#define private public
#include "backend/optimizer/ascend/ascend_pass_control.h"
#undef private
namespace {
constexpr char kMsAscendFusionSwitch[] = "MS_ASCEND_FUSION_SWITCH";
} // namespace
namespace mindspore {
namespace opt {
class PlantPass : public Pass {
public:
PlantPass() : Pass("plant") {}
~PlantPass() override = default;
bool GetRunStatus() const { return is_run_; }
bool Run(const FuncGraphPtr &) override {
is_run_ = true;
return true;
}
bool is_run_ = false;
};
class PlantPatternPass : public PatternProcessPass {
public:
PlantPatternPass() : PatternProcessPass("plant_pattern") { is_run_ = false; }
~PlantPatternPass() override = default;
bool GetRunStatus() const { return is_run_; }
const BaseRef DefinePattern() const override { return BaseRef({std::make_shared<Var>()}); }
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override {
is_run_ = true;
return nullptr;
}
inline static bool is_run_ = false;
};
class TestPass : public PassWithSwitch {
public:
TestPass() : PassWithSwitch("test") {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::DereluFusion);
}
~TestPass() override = default;
bool GetRunStatus() const { return is_run_; }
bool RunPass(const FuncGraphPtr &) override {
is_run_ = true;
return true;
}
bool is_run_ = false;
};
class TestPatternPass : public PatternProcessPassWithSwitch {
public:
TestPatternPass() : PatternProcessPassWithSwitch("test_pattern") {
PassSwitchManager::GetInstance().RegistLicPass(name(), OptPassEnum::DereluFusion);
is_run_ = false;
}
~TestPatternPass() override = default;
bool GetRunStatus() const { return is_run_; }
const BaseRef DefinePattern() const override { return BaseRef({std::make_shared<Var>()}); }
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override {
is_run_ = true;
return nullptr;
}
inline static bool is_run_ = false;
};
class TestAscendPassControl : public UT::Common {
public:
TestAscendPassControl() = default;
void TearDown() override {
(void)unsetenv(kMsAscendFusionSwitch);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
}
};
/// Feature: Pass Switch
/// Description: no MS_ASCEND_FUSION_SWITCH set and run pass
/// Expectation: switch pass run, plant pass run
TEST_F(TestAscendPassControl, test_no_env_for_pass) {
(void)unsetenv(kMsAscendFusionSwitch);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPass pass;
ASSERT_FALSE(pass.GetRunStatus());
pass.Run(nullptr);
ASSERT_TRUE(pass.GetRunStatus());
PlantPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(nullptr);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: set MS_ASCEND_FUSION_SWITCH as "on" and run pass
/// Expectation: switch pass run, plant pass run
TEST_F(TestAscendPassControl, test_env_on_for_pass_0) {
(void)setenv(kMsAscendFusionSwitch, "on", 1);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPass pass;
ASSERT_FALSE(pass.GetRunStatus());
pass.Run(nullptr);
ASSERT_TRUE(pass.GetRunStatus());
PlantPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(nullptr);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: set invalid MS_ASCEND_FUSION_SWITCH and run pass
/// Expectation: switch pass run, plant pass run
TEST_F(TestAscendPassControl, test_env_on_for_pass_1) {
(void)setenv(kMsAscendFusionSwitch, "invalidxxxxxxxx", 1);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPass pass;
ASSERT_FALSE(pass.GetRunStatus());
pass.Run(nullptr);
ASSERT_TRUE(pass.GetRunStatus());
PlantPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(nullptr);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: set MS_ASCEND_FUSION_SWITCH as "0" and run pass
/// Expectation: switch pass dont run, plant pass run
TEST_F(TestAscendPassControl, test_env_off_for_pass_0) {
(void)setenv(kMsAscendFusionSwitch, "0", 1);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPass pass;
ASSERT_FALSE(pass.GetRunStatus());
pass.Run(nullptr);
ASSERT_FALSE(pass.GetRunStatus());
PlantPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(nullptr);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: set MS_ASCEND_FUSION_SWITCH as "off" and run pass
/// Expectation: switch pass dont run, plant pass run
TEST_F(TestAscendPassControl, test_env_off_for_pass_1) {
(void)setenv(kMsAscendFusionSwitch, "off", 1);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPass pass;
ASSERT_FALSE(pass.GetRunStatus());
pass.Run(nullptr);
ASSERT_FALSE(pass.GetRunStatus());
PlantPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(nullptr);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: set MS_ASCEND_FUSION_SWITCH as "OFF" and run pass
/// Expectation: switch pass dont run, plant pass run
TEST_F(TestAscendPassControl, test_env_off_for_pass_2) {
(void)setenv(kMsAscendFusionSwitch, "OFF", 1);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPass pass;
ASSERT_FALSE(pass.GetRunStatus());
pass.Run(nullptr);
ASSERT_FALSE(pass.GetRunStatus());
PlantPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(nullptr);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: no MS_ASCEND_FUSION_SWITCH set and run pattern pass
/// Expectation: switch pass run, plant pass run
TEST_F(TestAscendPassControl, test_no_env_for_pattern_pass) {
(void)unsetenv(kMsAscendFusionSwitch);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPatternPass pass;
ASSERT_FALSE(pass.GetRunStatus());
FuncGraphPtr graph = std::make_shared<FuncGraph>();
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
pass.Run(graph, node);
ASSERT_TRUE(pass.GetRunStatus());
PlantPatternPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(graph, node);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: MS_ASCEND_FUSION_SWITCH set on and run pattern pass
/// Expectation: switch pass run, plant pass run
TEST_F(TestAscendPassControl, test_env_on_for_pattern_pass) {
(void)setenv(kMsAscendFusionSwitch, "on", 1);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPatternPass pass;
ASSERT_FALSE(pass.GetRunStatus());
FuncGraphPtr graph = std::make_shared<FuncGraph>();
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
pass.Run(graph, node);
ASSERT_TRUE(pass.GetRunStatus());
PlantPatternPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(graph, node);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: MS_ASCEND_FUSION_SWITCH set invalid and run pattern pass
/// Expectation: switch pass run, plant pass run
TEST_F(TestAscendPassControl, test_env_invalid_for_pattern_pass) {
(void)setenv(kMsAscendFusionSwitch, "invalid_xxasdasdasfsldjmg", 1);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPatternPass pass;
ASSERT_FALSE(pass.GetRunStatus());
FuncGraphPtr graph = std::make_shared<FuncGraph>();
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
pass.Run(graph, node);
ASSERT_TRUE(pass.GetRunStatus());
PlantPatternPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(graph, node);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: MS_ASCEND_FUSION_SWITCH set off and run pattern pass
/// Expectation: switch pass dont run, plant pass run
TEST_F(TestAscendPassControl, test_env_off_for_pattern_pass_0) {
(void)setenv(kMsAscendFusionSwitch, "off", 1);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPatternPass pass;
ASSERT_FALSE(pass.GetRunStatus());
FuncGraphPtr graph = std::make_shared<FuncGraph>();
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
pass.Run(graph, node);
ASSERT_FALSE(pass.GetRunStatus());
PlantPatternPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(graph, node);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: MS_ASCEND_FUSION_SWITCH set OFF and run pattern pass
/// Expectation: switch pass dont run, plant pass run
TEST_F(TestAscendPassControl, test_env_off_for_pattern_pass_1) {
(void)setenv(kMsAscendFusionSwitch, "OFF", 1);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPatternPass pass;
ASSERT_FALSE(pass.GetRunStatus());
FuncGraphPtr graph = std::make_shared<FuncGraph>();
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
pass.Run(graph, node);
ASSERT_FALSE(pass.GetRunStatus());
PlantPatternPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(graph, node);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
/// Feature: Pass Switch
/// Description: MS_ASCEND_FUSION_SWITCH set 0 and run pattern pass
/// Expectation: switch pass dont run, plant pass run
TEST_F(TestAscendPassControl, test_env_off_for_pattern_pass_2) {
(void)setenv(kMsAscendFusionSwitch, "0", 1);
PassSwitchManager::GetInstance().SetSwitchFromEnv();
TestPatternPass pass;
ASSERT_FALSE(pass.GetRunStatus());
FuncGraphPtr graph = std::make_shared<FuncGraph>();
CNodePtr node = graph->NewCNode({NewValueNode(std::make_shared<Primitive>("test"))});
pass.Run(graph, node);
ASSERT_FALSE(pass.GetRunStatus());
PlantPatternPass plant_pass;
ASSERT_FALSE(plant_pass.GetRunStatus());
plant_pass.Run(graph, node);
ASSERT_TRUE(plant_pass.GetRunStatus());
}
} // namespace opt
} // namespace mindspore