forked from mindspore-Ecosystem/mindspore
!15459 Add switch_defer_inline pass
From: @huangbingjian Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
d48151ab1e
|
@ -618,6 +618,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun
|
|||
std::vector<AnfNodePtr> parameters = func_graph->parameters();
|
||||
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
|
||||
|
||||
if (*(func_graph->switch_input())) {
|
||||
ofs << "switch_input: " << *(func_graph->switch_input()) << "\n";
|
||||
}
|
||||
if (*(func_graph->switch_layer_input())) {
|
||||
ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
|
||||
}
|
||||
|
|
|
@ -45,7 +45,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
|
|||
TraceGuard guard(std::make_shared<TraceGradFprop>(primal_graph->debug_info()));
|
||||
k_graph_ = std::make_shared<FuncGraph>();
|
||||
}
|
||||
// To keep switch_layer's inputs from being inlined
|
||||
// To keep switch or switch_layer's inputs from being inlined
|
||||
k_graph_->set_switch_input(primal_graph->switch_input());
|
||||
k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
|
||||
k_graph_->set_stage(primal_graph->stage());
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -47,7 +47,7 @@
|
|||
#include "frontend/optimizer/opt.h"
|
||||
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"
|
||||
#include "frontend/optimizer/irpass/switch_or_switch_layer_defer_inline.h"
|
||||
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -231,6 +231,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
|
||||
{prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});
|
||||
|
||||
// switch defer inline
|
||||
switch_defer_inline_ =
|
||||
MakeSubstitution(std::make_shared<SwitchDeferInline>(), "switch_defer_inline", prim::kPrimSwitch);
|
||||
|
||||
// switch_layer defer inline
|
||||
switch_layer_defer_inline_ =
|
||||
MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -139,6 +139,9 @@ class OptimizeIRPassLib {
|
|||
// Value_Based Eliminate
|
||||
SubstitutionPtr value_based_eliminate_;
|
||||
|
||||
// Switch defer inline
|
||||
SubstitutionPtr switch_defer_inline_;
|
||||
|
||||
// SwitchLayer defer inline
|
||||
SubstitutionPtr switch_layer_defer_inline_;
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -41,7 +41,8 @@ class ReplaceApplicator : public AnfVisitor {
|
|||
}
|
||||
|
||||
auto fg = GetValueNode<FuncGraphPtr>(node);
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_layer_input())) {
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_input()) ||
|
||||
*(fg->switch_layer_input())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -28,12 +28,29 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimSwitchLayer, {Index, layers}}
|
||||
// {prim::kPrimSwitch, cond, true_branch, false_branch}
|
||||
class SwitchDeferInline : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto true_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(2)->abstract());
|
||||
if (true_abstract != nullptr) {
|
||||
*(true_abstract->func_graph()->switch_input()) = true;
|
||||
}
|
||||
auto false_abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(cnode->input(3)->abstract());
|
||||
if (false_abstract != nullptr) {
|
||||
*(false_abstract->func_graph()->switch_input()) = true;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// {prim::kPrimSwitchLayer, Index, layers}
|
||||
class SwitchLayerDeferInline : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
|
||||
auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->input(2)->abstract());
|
||||
for (auto elem : tuple->elements()) {
|
||||
auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
|
||||
if (abstract != nullptr) {
|
|
@ -97,6 +97,7 @@ bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { r
|
|||
|
||||
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig a_1 = opt::OptPassConfig({
|
||||
irpass.switch_defer_inline_,
|
||||
irpass.switch_layer_defer_inline_,
|
||||
irpass.switch_simplify_,
|
||||
irpass.exchange_switch_depend_value_,
|
||||
|
|
|
@ -50,6 +50,7 @@ FuncGraph::FuncGraph()
|
|||
stub_(false),
|
||||
stage_(-1) {
|
||||
debug_info_ = std::make_shared<GraphDebugInfo>();
|
||||
switch_input_ = std::make_shared<bool>(false);
|
||||
switch_layer_input_ = std::make_shared<bool>(false);
|
||||
}
|
||||
|
||||
|
|
|
@ -381,6 +381,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
bool stub() const { return stub_; }
|
||||
void set_stub(bool stub) { stub_ = stub; }
|
||||
static void set_drawer(Drawer drawer) { drawer_ = drawer; }
|
||||
std::shared_ptr<bool> switch_input() const { return switch_input_; }
|
||||
void set_switch_input(std::shared_ptr<bool> switch_input) { switch_input_ = switch_input; }
|
||||
std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
|
||||
void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
|
||||
bool ContainMultiTarget() const;
|
||||
|
@ -462,8 +464,9 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
|
|||
OrderedSet<CNodePtr> order_;
|
||||
bool stub_;
|
||||
inline static Drawer drawer_ = nullptr;
|
||||
// Design switch_layer_input as a ptr to
|
||||
// Design switch_input and switch_layer_input as a ptr to
|
||||
// share between derived backpropagator and cloned graphs.
|
||||
std::shared_ptr<bool> switch_input_;
|
||||
std::shared_ptr<bool> switch_layer_input_;
|
||||
int64_t stage_;
|
||||
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
|
||||
|
|
|
@ -233,6 +233,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
|
|||
(*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
|
||||
(*target_func_graph)->set_is_generate(func_graph->is_generated());
|
||||
(*target_func_graph)->set_stub(func_graph->stub());
|
||||
(*target_func_graph)->set_switch_input(func_graph->switch_input());
|
||||
(*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
|
||||
}
|
||||
|
||||
|
@ -680,6 +681,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
|
|||
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
|
||||
new_func_graph->set_is_generate(func_graph->is_generated());
|
||||
new_func_graph->set_stub(func_graph->stub());
|
||||
new_func_graph->set_switch_input(func_graph->switch_input());
|
||||
new_func_graph->set_switch_layer_input(func_graph->switch_layer_input());
|
||||
for (auto &item : func_graph->parameter_default_value()) {
|
||||
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
|
||||
|
|
Loading…
Reference in New Issue