forked from mindspore-Ecosystem/mindspore

fix switch layer single prim cell

parent 10015ad9b2
commit 1c296b96c9
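
The scenario this commit targets is the one exercised by the test added at the bottom of the diff: a cell whose branches are single-primitive cells (nn.ReLU()) selected out of a tuple by a tensor index, which compiles to a prim::kPrimSwitchLayer node in graph mode. Below is a minimal standalone sketch of that usage pattern; it mirrors the new test_switch_layer_with_single_prim test, while the imports, the graph-mode context setup, and the SingleLayerSwitchCell name are illustrative assumptions rather than part of this diff.

# Minimal repro sketch of the switch_layer-with-single-primitive-cell pattern.
# Mirrors test_switch_layer_with_single_prim below; names and setup are illustrative.
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE)


class SingleLayerSwitchCell(nn.Cell):
    def __init__(self):
        super(SingleLayerSwitchCell, self).__init__()
        # Each branch is a single primitive cell; indexing the tuple with a
        # tensor-valued index becomes a switch_layer call in the compiled graph.
        self.layers = (nn.ReLU(), nn.ReLU())
        self.scale = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='scale')

    def construct(self, index, x):
        return self.layers[index](x) * self.scale


net = SingleLayerSwitchCell()
index = Tensor(0, dtype=mstype.int32)
out = net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
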
@@ -607,6 +607,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun
   std::vector<AnfNodePtr> parameters = func_graph->parameters();
   OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
 
+  if (*(func_graph->switch_layer_input())) {
+    ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
+  }
   ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
       << func_graph->debug_info()->get_id() << "\n";
   if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {

@@ -49,6 +49,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
     std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
     k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
   }
+  // To keep switch_layer's inputs from being inlined
+  k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
   TraceManager::EndTrace();
 
   TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));

@@ -45,6 +45,7 @@
 #include "frontend/optimizer/opt.h"
 #include "frontend/optimizer/irpass/row_tensor_eliminate.h"
 #include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
+#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"
 
 namespace mindspore {
 namespace opt {

@@ -170,6 +171,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
   // Value_Based Eliminate
   value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
                                             {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});
+
+  // switch_layer defer inline
+  switch_layer_defer_inline_ =
+    MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
 }
 
 ResolveIRPassLib::ResolveIRPassLib() {

@@ -113,6 +113,9 @@ class OptimizeIRPassLib {
 
   // Value_Based Eliminate
   SubstitutionPtr value_based_eliminate_;
+
+  // SwitchLayer defer inline
+  SubstitutionPtr switch_layer_defer_inline_;
 };
 
 // the collection of irpass for resolve action

@@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor {
     }
 
     auto fg = GetValueNode<FuncGraphPtr>(node);
-    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
+    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_layer_input())) {
       return nullptr;
     }
 

@@ -0,0 +1,47 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
+#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
+
+#include <vector>
+#include <algorithm>
+
+#include "frontend/optimizer/irpass.h"
+#include "frontend/optimizer/optimizer.h"
+#include "frontend/optimizer/anf_visitor.h"
+#include "frontend/operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+namespace irpass {
+// {prim::kPrimSwitchLayer, {Index, layers}}
+class SwitchLayerDeferInline : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    auto cnode = node->cast<CNodePtr>();
+    auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
+    for (auto elem : tuple->elements()) {
+      auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
+      *(abstract->func_graph()->switch_layer_input()) = true;
+    }
+    return nullptr;
+  }
+};
+}  // namespace irpass
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_

@@ -90,6 +90,7 @@ bool CleanAfterOptAPass(const ResourcePtr &res) {
 namespace {
 OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig a_1 = opt::OptPassConfig({
+    irpass.switch_layer_defer_inline_,
     irpass.switch_simplify_,
 
     // Safe inlining

@@ -48,6 +48,7 @@ FuncGraph::FuncGraph()
       manager_(std::weak_ptr<FuncGraphManager>()),
       stub_(false) {
   debug_info_ = std::make_shared<GraphDebugInfo>();
+  switch_layer_input_ = std::make_shared<bool>(false);
 }
 
 abstract::AbstractBasePtr FuncGraph::ToAbstract() {

@@ -353,6 +353,8 @@ class FuncGraph : public FuncGraphBase {
   bool stub() const { return stub_; }
   void set_stub(bool stub) { stub_ = stub; }
   static void set_drawer(Drawer drawer) { drawer_ = drawer; }
+  std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
+  void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
 
  private:
   // graph is manipulated by manager and others

@@ -414,6 +416,9 @@ class FuncGraph : public FuncGraphBase {
   std::list<CNodePtr> order_;
   bool stub_;
   inline static Drawer drawer_ = nullptr;
+  // Design switch_layer_input as a ptr to
+  // share between derived backpropagator and cloned graphs
+  std::shared_ptr<bool> switch_layer_input_;
 };
 
 inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {

@@ -228,6 +228,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
   (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
   (*target_func_graph)->set_is_generate(func_graph->is_generated());
   (*target_func_graph)->set_stub(func_graph->stub());
+  (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
   TraceManager::EndTrace();
 }
 

@@ -645,6 +646,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
   new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
   new_func_graph->set_is_generate(func_graph->is_generated());
   new_func_graph->set_stub(func_graph->stub());
+  new_func_graph->set_switch_layer_input(func_graph->switch_layer_input());
   for (auto &item : func_graph->parameter_default_value()) {
     new_func_graph->set_param_default_value(item.first, cloner[item.second]);
   }

@@ -444,6 +444,26 @@ def test_index_to_switch_layer():
     C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
 
 
+def test_switch_layer_with_single_prim():
+    class SwitchLayerCell(nn.Cell):
+        def __init__(self):
+            super(SwitchLayerCell, self).__init__()
+            self.layers = (nn.ReLU(), nn.ReLU())
+            self.z3 = Parameter(
+                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+
+        def construct(self, index, x):
+            ret = self.layers[index](x) * self.z3
+            return ret
+
+    index = Tensor(0, dtype=mstype.int32)
+    net = SwitchLayerCell()
+    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
+                                                                Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+
+
 def test_control_depend_check():
     with pytest.raises(TypeError) as e:
         P.ControlDepend(0.0)