fix switch layer single prim cell

panyifeng 2020-08-10 11:44:53 +08:00
parent 10015ad9b2
commit 1c296b96c9
11 changed files with 90 additions and 1 deletion

View File

@@ -607,6 +607,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun
  std::vector<AnfNodePtr> parameters = func_graph->parameters();
  OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
  if (*(func_graph->switch_layer_input())) {
    ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
  }
  ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
      << func_graph->debug_info()->get_id() << "\n";
  if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
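Note: since the flag is printed only when set, a marked graph's textual dump gains a line of the form "switch_layer_input: 1" (std::ostream renders a bool as 1 or 0 unless std::boolalpha is applied).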

View File

@@ -49,6 +49,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
    std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
    k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
  }
  // To keep switch_layer's inputs from being inlined
  k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
  TraceManager::EndTrace();
  TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
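Note: DFunctor builds the grad transform of primal_graph into k_graph_. Handing it the same shared pointer, rather than a copy of its value, keeps the no-inline mark in sync between a switch_layer branch and its derived backpropagator.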

View File

@@ -45,6 +45,7 @@
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"

namespace mindspore {
namespace opt {
@@ -170,6 +171,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
  // Value_Based Eliminate
  value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
                                            {prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});

  // switch_layer defer inline
  switch_layer_defer_inline_ =
    MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
}

ResolveIRPassLib::ResolveIRPassLib() {

View File

@@ -113,6 +113,9 @@ class OptimizeIRPassLib {
  // Value_Based Eliminate
  SubstitutionPtr value_based_eliminate_;

  // SwitchLayer defer inline
  SubstitutionPtr switch_layer_defer_inline_;
};

// the collection of irpass for resolve action

View File

@@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor {
    }

    auto fg = GetValueNode<FuncGraphPtr>(node);
    if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_layer_input())) {
      return nullptr;
    }
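Note: ReplaceApplicator replaces a call to a trivial func graph with an equivalent direct expression; the added third condition exempts graphs carrying the switch_layer mark. A branch of switch_layer has to stay a func graph at runtime, and a cell wrapping a single primitive, the case named in the commit title, is exactly the kind of trivial graph this pass would otherwise fold away.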

View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
#include <vector>
#include <algorithm>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimSwitchLayer, {Index, layers}}
class SwitchLayerDeferInline : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    auto cnode = node->cast<CNodePtr>();
    if (cnode == nullptr || cnode->inputs().size() <= 2) {
      return nullptr;
    }
    // inputs()[2] is the tuple of branch graphs handed to switch_layer.
    auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
    if (tuple == nullptr) {
      return nullptr;
    }
    for (auto elem : tuple->elements()) {
      auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
      if (abstract == nullptr || abstract->func_graph() == nullptr) {
        continue;
      }
      // Set the shared flag so inline passes leave this branch graph intact.
      *(abstract->func_graph()->switch_layer_input()) = true;
    }
    return nullptr;
  }
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
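Note: per the pattern comment above the class, the visitor runs on CNodes of the form switch_layer(index, layers), where inputs()[1] is the branch index and inputs()[2] the tuple of branch graphs. Schematically (not actual IR syntax):

  layers = make_tuple(@g1, @g2, ...)
  out = switch_layer(index, layers)

The pass rewrites nothing: returning nullptr keeps the node as-is, and its only effect is setting the shared flag on each branch graph. It is registered solely for prim::kPrimSwitchLayer (see the MakeSubstitution call in the OptimizeIRPassLib hunk above), so in practice it never sees other node kinds.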

View File

@@ -90,6 +90,7 @@ bool CleanAfterOptAPass(const ResourcePtr &res) {
namespace {
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig a_1 = opt::OptPassConfig({
    irpass.switch_layer_defer_inline_,
    irpass.switch_simplify_,

    // Safe inlining
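Note: placement matters here. switch_layer_defer_inline_ sits at the head of opt pass group a_1, ahead of switch_simplify_ and the "safe inlining" passes that follow, so branch graphs are marked before any inliner runs.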

View File

@@ -48,6 +48,7 @@ FuncGraph::FuncGraph()
      manager_(std::weak_ptr<FuncGraphManager>()),
      stub_(false) {
  debug_info_ = std::make_shared<GraphDebugInfo>();
  switch_layer_input_ = std::make_shared<bool>(false);
}

abstract::AbstractBasePtr FuncGraph::ToAbstract() {

View File

@@ -353,6 +353,8 @@ class FuncGraph : public FuncGraphBase {
  bool stub() const { return stub_; }
  void set_stub(bool stub) { stub_ = stub; }
  static void set_drawer(Drawer drawer) { drawer_ = drawer; }
  std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
  void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }

 private:
  // graph is manipulated by manager and others
@@ -414,6 +416,9 @@ class FuncGraph : public FuncGraphBase {
  std::list<CNodePtr> order_;
  bool stub_;
  inline static Drawer drawer_ = nullptr;

  // switch_layer_input_ is a pointer so the mark can be shared between the
  // derived backpropagator and cloned graphs.
  std::shared_ptr<bool> switch_layer_input_;
};
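Note: the reason for a std::shared_ptr<bool> member rather than a plain bool is aliasing: every graph derived from a primal (the DFunctor's k_graph_ above, the clones below) receives the same pointer, so marking any one of them marks all of them. A minimal standalone sketch of that behavior, with a hypothetical Graph type standing in for FuncGraph:

#include <iostream>
#include <memory>

// Hypothetical stand-in for FuncGraph: each graph carries a shared mark.
struct Graph {
  std::shared_ptr<bool> switch_layer_input = std::make_shared<bool>(false);
};

// Mimics Cloner::SetFuncGraphInfo / DFunctor: alias the flag, don't copy its value.
Graph DeriveFrom(const Graph &primal) {
  Graph derived;
  derived.switch_layer_input = primal.switch_layer_input;
  return derived;
}

int main() {
  Graph primal;
  Graph k_graph = DeriveFrom(primal);
  *primal.switch_layer_input = true;                 // mark only the primal
  std::cout << *k_graph.switch_layer_input << "\n";  // prints 1: the derived graph sees the mark
  return 0;
}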
inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) { inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {

View File

@@ -228,6 +228,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
  (*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
  (*target_func_graph)->set_is_generate(func_graph->is_generated());
  (*target_func_graph)->set_stub(func_graph->stub());
  (*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
  TraceManager::EndTrace();
}
@@ -645,6 +646,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
  new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
  new_func_graph->set_is_generate(func_graph->is_generated());
  new_func_graph->set_stub(func_graph->stub());
  new_func_graph->set_switch_layer_input(func_graph->switch_layer_input());
  for (auto &item : func_graph->parameter_default_value()) {
    new_func_graph->set_param_default_value(item.first, cloner[item.second]);
  }

View File

@@ -444,6 +444,26 @@ def test_index_to_switch_layer():
    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))


def test_switch_layer_with_single_prim():
    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (nn.ReLU(), nn.ReLU())
            self.z3 = Parameter(
                Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')

        def construct(self, index, x):
            ret = self.layers[index](x) * self.z3
            return ret

    index = Tensor(0, dtype=mstype.int32)
    net = SwitchLayerCell()
    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
                                                                Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))


def test_control_depend_check():
    with pytest.raises(TypeError) as e:
        P.ControlDepend(0.0)
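Note: the new test is the regression check for the commit title: each branch in self.layers is a bare nn.ReLU(), a cell whose graph wraps a single primitive, and it is driven through forward execution, C.grad_by_list, and C.grad_all, the paths on which such a branch could previously be inlined out from under switch_layer.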