forked from mindspore-Ecosystem/mindspore
fix switch layer sigle prim cell
This commit is contained in:
parent
10015ad9b2
commit
1c296b96c9
|
@ -607,6 +607,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun
|
|||
std::vector<AnfNodePtr> parameters = func_graph->parameters();
|
||||
OrderedMap<AnfNodePtr, int, ParamPtrHasher, ParamPtrEqual> param_map;
|
||||
|
||||
if (*(func_graph->switch_layer_input())) {
|
||||
ofs << "switch_layer_input: " << *(func_graph->switch_layer_input()) << "\n";
|
||||
}
|
||||
ofs << "# [No." << (exported.size() + 1) << "] " << func_graph->DumpText() << "."
|
||||
<< func_graph->debug_info()->get_id() << "\n";
|
||||
if (label_manage::GetGlobalTraceLabelType() == label_manage::TraceLabelType::kWithUniqueId) {
|
||||
|
|
|
@ -49,6 +49,8 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
|
|||
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
|
||||
k_graph_->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(grad_op_name));
|
||||
}
|
||||
// To keep switch_layer's inputs from being inlined
|
||||
k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
|
||||
TraceManager::EndTrace();
|
||||
|
||||
TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
|
||||
|
|
|
@ -45,6 +45,7 @@
|
|||
#include "frontend/optimizer/opt.h"
|
||||
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
|
||||
#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
|
@ -170,6 +171,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
// Value_Based Eliminate
|
||||
value_based_eliminate_ = MakeSubstitution(std::make_shared<ValueBasedEliminate>(), "value_based_eliminate",
|
||||
{prim::kPrimSelect, prim::kPrimMinimum, prim::kPrimMaximum});
|
||||
|
||||
// switch_layer defer inline
|
||||
switch_layer_defer_inline_ =
|
||||
MakeSubstitution(std::make_shared<SwitchLayerDeferInline>(), "switch_layer_defer_inline", prim::kPrimSwitchLayer);
|
||||
}
|
||||
|
||||
ResolveIRPassLib::ResolveIRPassLib() {
|
||||
|
|
|
@ -113,6 +113,9 @@ class OptimizeIRPassLib {
|
|||
|
||||
// Value_Based Eliminate
|
||||
SubstitutionPtr value_based_eliminate_;
|
||||
|
||||
// SwitchLayer defer inline
|
||||
SubstitutionPtr switch_layer_defer_inline_;
|
||||
};
|
||||
|
||||
// the collection of irpass for resolve action
|
||||
|
|
|
@ -39,7 +39,7 @@ class ReplaceApplicator : public AnfVisitor {
|
|||
}
|
||||
|
||||
auto fg = GetValueNode<FuncGraphPtr>(node);
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
|
||||
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_layer_input())) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
|
||||
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#include "frontend/optimizer/irpass.h"
|
||||
#include "frontend/optimizer/optimizer.h"
|
||||
#include "frontend/optimizer/anf_visitor.h"
|
||||
#include "frontend/operator/ops.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
// {prim::kPrimSwitchLayer, {Index, layers}}
|
||||
class SwitchLayerDeferInline : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
auto tuple = dyn_cast<abstract::AbstractTuple>(cnode->inputs()[2]->abstract());
|
||||
for (auto elem : tuple->elements()) {
|
||||
auto abstract = dyn_cast<abstract::FuncGraphAbstractClosure>(elem);
|
||||
*(abstract->func_graph()->switch_layer_input()) = true;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_SWITCH_LAYER_DEFER_INLINE_H_
|
|
@ -90,6 +90,7 @@ bool CleanAfterOptAPass(const ResourcePtr &res) {
|
|||
namespace {
|
||||
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
|
||||
opt::OptPassConfig a_1 = opt::OptPassConfig({
|
||||
irpass.switch_layer_defer_inline_,
|
||||
irpass.switch_simplify_,
|
||||
|
||||
// Safe inlining
|
||||
|
|
|
@ -48,6 +48,7 @@ FuncGraph::FuncGraph()
|
|||
manager_(std::weak_ptr<FuncGraphManager>()),
|
||||
stub_(false) {
|
||||
debug_info_ = std::make_shared<GraphDebugInfo>();
|
||||
switch_layer_input_ = std::make_shared<bool>(false);
|
||||
}
|
||||
|
||||
abstract::AbstractBasePtr FuncGraph::ToAbstract() {
|
||||
|
|
|
@ -353,6 +353,8 @@ class FuncGraph : public FuncGraphBase {
|
|||
bool stub() const { return stub_; }
|
||||
void set_stub(bool stub) { stub_ = stub; }
|
||||
static void set_drawer(Drawer drawer) { drawer_ = drawer; }
|
||||
std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
|
||||
void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
|
||||
|
||||
private:
|
||||
// graph is manipulated by manager and others
|
||||
|
@ -414,6 +416,9 @@ class FuncGraph : public FuncGraphBase {
|
|||
std::list<CNodePtr> order_;
|
||||
bool stub_;
|
||||
inline static Drawer drawer_ = nullptr;
|
||||
// Design switch_layer_input as a ptr to
|
||||
// share between derived backpropagator and cloned graphs
|
||||
std::shared_ptr<bool> switch_layer_input_;
|
||||
};
|
||||
|
||||
inline CNodePtr NewCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &fg) {
|
||||
|
|
|
@ -228,6 +228,7 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *cons
|
|||
(*target_func_graph)->set_hyper_param_count(func_graph->hyper_param_count());
|
||||
(*target_func_graph)->set_is_generate(func_graph->is_generated());
|
||||
(*target_func_graph)->set_stub(func_graph->stub());
|
||||
(*target_func_graph)->set_switch_layer_input(func_graph->switch_layer_input());
|
||||
TraceManager::EndTrace();
|
||||
}
|
||||
|
||||
|
@ -645,6 +646,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
|
|||
new_func_graph->set_hyper_param_count(func_graph->hyper_param_count());
|
||||
new_func_graph->set_is_generate(func_graph->is_generated());
|
||||
new_func_graph->set_stub(func_graph->stub());
|
||||
new_func_graph->set_switch_layer_input(func_graph->switch_layer_input());
|
||||
for (auto &item : func_graph->parameter_default_value()) {
|
||||
new_func_graph->set_param_default_value(item.first, cloner[item.second]);
|
||||
}
|
||||
|
|
|
@ -444,6 +444,26 @@ def test_index_to_switch_layer():
|
|||
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||
|
||||
|
||||
def test_switch_layer_with_single_prim():
|
||||
class SwitchLayerCell(nn.Cell):
|
||||
def __init__(self):
|
||||
super(SwitchLayerCell, self).__init__()
|
||||
self.layers = (nn.ReLU(), nn.ReLU())
|
||||
self.z3 = Parameter(
|
||||
Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
|
||||
|
||||
def construct(self, index, x):
|
||||
ret = self.layers[index](x) * self.z3
|
||||
return ret
|
||||
|
||||
index = Tensor(0, dtype=mstype.int32)
|
||||
net = SwitchLayerCell()
|
||||
net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||
C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
|
||||
Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||
C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
|
||||
|
||||
|
||||
def test_control_depend_check():
|
||||
with pytest.raises(TypeError) as e:
|
||||
P.ControlDepend(0.0)
|
||||
|
|
Loading…
Reference in New Issue