!15459 Add switch_defer_inline pass

From: @huangbingjian
Reviewed-by: @ginfung, @zh_qh
Signed-off-by: @zh_qh
Commit d48151ab1e by mindspore-ci-bot, 2021-04-21 15:10:53 +08:00, committed by Gitee
10 changed files with 46 additions and 10 deletions
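The change generalizes the existing switch_layer handling to plain switch: a new SwitchDeferInline pass sets a boolean flag on both branch graphs of every prim::kPrimSwitch node, and the inline passes skip any graph whose flag is set. A minimal sketch of that mark-and-check interplay, using a hypothetical Graph type rather than MindSpore's real FuncGraph API:

#include <memory>
#include <vector>

// Hypothetical stand-in for FuncGraph; the flag sits behind a shared_ptr so
// that clones of a graph observe the same value (see func_graph.h below).
struct Graph {
  std::shared_ptr<bool> switch_input = std::make_shared<bool>(false);
};

// Plays the role of SwitchDeferInline: mark every branch graph of a switch.
void MarkSwitchBranches(const std::vector<Graph *> &branches) {
  for (auto *branch : branches) {
    *(branch->switch_input) = true;
  }
}

// Plays the role of the inline passes' guard: never inline a marked graph.
bool CanInline(const Graph &g) { return !*(g.switch_input); }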

@@ -618,6 +618,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) {
   std::vector<AnfNodePtr> parameters = func_graph->parameters();
   OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
+  if (*(func_graph->switch_input())) {
+    ofs << "switch_input: " << *(func_graph->switch_input()) << "\n";
+  }
   if (*(func_graph->switch_layer_input())) {
     ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
   }
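Note that switch_input() returns the std::shared_ptr<bool> itself, so both the guard and the stream output above dereference it; testing the undecorated pointer would almost always yield true. A small standalone illustration:

#include <memory>

int main() {
  auto flag = std::make_shared<bool>(false);
  bool as_pointer = static_cast<bool>(flag);  // true: the pointer is non-null
  bool as_value = *flag;                      // false: the flag's actual value
  return (as_pointer && !as_value) ? 0 : 1;
}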

@@ -45,7 +45,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
     TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
     k_graph_ = std::make_shared<FuncGraph>();
   }
-  // To keep switch_layer's inputs from being inlined
+  // To keep switch or switch_layer's inputs from being inlined
+  k_graph_->set_switch_input(primal_graph->switch_input());
   k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
   k_graph_->set_stage(primal_graph->stage());

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -47,7 +47,7 @@
 #include "frontend/optimizer/opt.h"
 #include "frontend/optimizer/irpass/row_tensor_eliminate.h"
 #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
-#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"
+#include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
 #include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
 
 namespace mindspore {
@@ -231,6 +231,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
                                             {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});
 
+  // switch defer inline
+  switch_defer_inline_ =
+    MakeSubstitution(std::make_shared<SwitchDeferInline>(), "switch_defer_inline", prim::kPrimSwitch);
+
   // switch_layer defer inline
   switch_layer_defer_inline_ =
     MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
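MakeSubstitution binds a visitor to the primitive that triggers it, so switch_defer_inline_ only ever visits prim::kPrimSwitch nodes. A rough sketch of that name-plus-trigger registration, with hypothetical types standing in for MindSpore's Substitution machinery:

#include <functional>
#include <string>
#include <utility>

// Hypothetical stand-ins: a node tagged with its head primitive, and a
// substitution that fires only when that primitive matches its trigger.
struct Node {
  std::string head_primitive;
};

struct Substitution {
  std::function<void(Node &)> visitor;
  std::string name;
  std::string trigger;

  void Run(Node &node) {
    if (node.head_primitive == trigger) {
      visitor(node);
    }
  }
};

Substitution MakeSubstitutionSketch(std::function<void(Node &)> v, std::string name, std::string trigger) {
  return Substitution{std::move(v), std::move(name), std::move(trigger)};
}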

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -139,6 +139,9 @@ class OptimizeIRPassLib {
   // Value_Based Eliminate
   SubstitutionPtr value_based_eliminate_;
 
+  // Switch defer inline
+  SubstitutionPtr switch_defer_inline_;
+
   // SwitchLayer defer inline
   SubstitutionPtr switch_layer_defer_inline_;

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -41,7 +41,8 @@ class ReplaceApplicator : public AnfVisitor {
     }
     auto fg = GetValueNode<FuncGraphPtr>(node);
-    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_layer_input())) {
+    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_input()) ||
+        *(fg->switch_layer_input())) {
       return nullptr;
     }

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,12 +28,29 @@
 namespace mindspore {
 namespace opt {
 namespace irpass {
-// {prim::kPrimSwitchLayer, {Index, layers}}
+// {prim::kPrimSwitch, cond, true_branch, false_branch}
+class SwitchDeferInline : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    auto cnode = node->cast<CNodePtr>();
+    auto true_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(2)->abstract());
+    if (true_abstract != nullptr) {
+      *(true_abstract->func_graph()->switch_input()) = true;
+    }
+    auto false_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(3)->abstract());
+    if (false_abstract != nullptr) {
+      *(false_abstract->func_graph()->switch_input()) = true;
+    }
+    return nullptr;
+  }
+};
+
+// {prim::kPrimSwitchLayer, Index, layers}
 class SwitchLayerDeferInline : public AnfVisitor {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
     auto cnode = node->cast<CNodePtr>();
-    auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
+    auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->input(2)->abstract());
     for (auto elem : tuple->elements()) {
       auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
       if (abstract != nullptr) {
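The input(2)/input(3) indices follow the CNode layout given in the comments: input(0) is the primitive itself, so each operand sits one slot past its position in the {prim, ...} notation. Note also that the visitor returns nullptr, meaning it never rewrites the node; its only effect is flipping the shared flags. A toy check of the index convention, using plain strings rather than real AnfNodes:

#include <cassert>
#include <string>
#include <vector>

int main() {
  // {prim::kPrimSwitch, cond, true_branch, false_branch} flattened as inputs:
  std::vector<std::string> inputs = {"kPrimSwitch", "cond", "true_branch", "false_branch"};
  assert(inputs[2] == "true_branch");   // what cnode->input(2) fetches
  assert(inputs[3] == "false_branch");  // what cnode->input(3) fetches
  return 0;
}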

@@ -97,6 +97,7 @@ bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); }
 OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig a_1 = opt::OptPassConfig({
+    irpass.switch_defer_inline_,
     irpass.switch_layer_defer_inline_,
     irpass.switch_simplify_,
     irpass.exchange_switch_depend_value_,

@@ -50,6 +50,7 @@ FuncGraph::FuncGraph()
       stub_(false),
       stage_(-1) {
   debug_info_ = std::make_shared<GraphDebugInfo>();
+  switch_input_ = std::make_shared<bool>(false);
   switch_layer_input_ = std::make_shared<bool>(false);
 }

@@ -381,6 +381,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   bool stub() const { return stub_; }
   void set_stub(bool stub) { stub_ = stub; }
   static void set_drawer(Drawer drawer) { drawer_ = drawer; }
+  std::shared_ptr<bool> switch_input() const { return switch_input_; }
+  void set_switch_input(std::shared_ptr<bool> switch_input) { switch_input_ = switch_input; }
   std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
   void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
   bool ContainMultiTarget() const;
@@ -462,8 +464,9 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   OrderedSet<CNodePtr> order_;
   bool stub_;
   inline static Drawer drawer_ = nullptr;
-  // Design switch_layer_input as a ptr to
+  // Design switch_input and switch_layer_input as a ptr to
   // share between derived backpropagator and cloned graphs.
+  std::shared_ptr<bool> switch_input_;
   std::shared_ptr<bool> switch_layer_input_;
   int64_t stage_;
   std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,

@@ -233,6 +233,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) {
   (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
   (*target_func_graph)->set_is_generate(func_graph->is_generated());
   (*target_func_graph)->set_stub(func_graph->stub());
+  (*target_func_graph)->set_switch_input(func_graph->switch_input());
   (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
 }
@@ -680,6 +681,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) {
   new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
   new_func_graph->set_is_generate(func_graph->is_generated());
   new_func_graph->set_stub(func_graph->stub());
+  new_func_graph->set_switch_input(func_graph->switch_input());
   new_func_graph->set_switch_layer_input(func_graph->switch_layer_input());
   for (auto &item : func_graph->parameter_default_value()) {
     new_func_graph->set_param_default_value(item.first, cloner[item.second]);
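Both cloning paths copy the shared pointer rather than the boolean value, which is what the func_graph.h comment means by sharing between the derived backpropagator and cloned graphs: a graph cloned before SwitchDeferInline runs still observes the flag flipping afterwards. A self-contained demonstration with a hypothetical Graph type:

#include <iostream>
#include <memory>

struct Graph {
  std::shared_ptr<bool> switch_input = std::make_shared<bool>(false);
};

int main() {
  Graph original;
  Graph clone;
  clone.switch_input = original.switch_input;  // what set_switch_input() propagates

  *original.switch_input = true;  // SwitchDeferInline marks the original later
  // The clone reads the same flag, so it is protected from inlining too.
  std::cout << std::boolalpha << *clone.switch_input << std::endl;  // prints true
  return 0;
}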