forked from mindspore-Ecosystem/mindspore
Add switch_case primitive
This commit is contained in:
parent
5a03bd8077
commit
b7596e1f33
@@ -59,6 +59,7 @@ const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
 // Statements
 const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
+const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
 const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
 const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
 const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
@@ -65,6 +65,7 @@ extern const PrimitivePtr kPrimHasType;
 // Statements
 extern const PrimitivePtr kPrimSwitch;
+extern const PrimitivePtr kPrimSwitchLayer;
 extern const PrimitivePtr kPrimReturn;
 extern const PrimitivePtr kPrimAssign;
 extern const PrimitivePtr kPrimAssignAdd;
@@ -126,6 +126,30 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
   MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString();
 }
 
+AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                     const AbstractBasePtrList &args_spec_list) {
+  // Inputs: index, branch
+  if (args_spec_list.size() != 2) {
+    MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 2 parameters, while the input size is "
+                      << args_spec_list.size() << ".";
+  }
+  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(primitive->name(), args_spec_list, 1);
+  AbstractBasePtrList branches = branches_abs->elements();
+  const size_t maximum_layer_num = 1000;
+  if (branches.size() < 1 || branches.size() > maximum_layer_num) {
+    MS_EXCEPTION(ValueError) << "SwitchLayer supports at least 1 and at most " << maximum_layer_num
+                             << " branches, but got " << branches.size() << ".";
+  }
+
+  MS_EXCEPTION_IF_NULL(branches[0]);
+  auto b = branches[0];
+  for (size_t i = 1; i < branches.size(); i++) {
+    MS_EXCEPTION_IF_NULL(branches[i]);
+    b = b->Join(branches[i]);
+  }
+  return b;
+}
+
 std::vector<ValuePtr> GetSupportedTargetValue() {
   std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
   return list;
@@ -38,6 +38,7 @@ namespace mindspore {
 namespace ad {
 std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
 std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
+FuncGraphSet DFunctor::scope_;
 
 DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
     : primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
@@ -55,11 +56,15 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
 void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
   func_graph_to_functor_[primal_graph_] = functor;
   is_top_ = is_top;
+  if (is_top) {
+    scope_ = primal_graph_->scope();
+  }
 }
 
 void DFunctor::Clear() {
   func_graph_to_functor_.clear();
   anfnode_to_adjoin_definition_.clear();
+  scope_.clear();
 }
 
 void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
@@ -95,11 +100,48 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
   fv_adjoint->second->AccumulateDout(dfv);
 }
 
+void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
+  // Take switch_layer as a set of candidate functions.
+  auto input = cnode_morph->input(2);
+  if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
+    MS_LOG(EXCEPTION) << "The second input of switch_layer should be a tuple of graphs, but got " << input->ToString() << ".";
+  }
+  auto tuple_graphs = input->cast<CNodePtr>();
+  for (size_t i = 1; i < tuple_graphs->size(); ++i) {
+    auto graph = tuple_graphs->input(i);
+    if (!IsValueNode<FuncGraph>(graph)) {
+      MS_LOG(EXCEPTION) << "The second input of switch_layer should be a tuple of graphs, but got " << graph->ToString()
+                        << " as the " << i << "th element.";
+    }
+    auto func_graph = GetValueNode<FuncGraphPtr>(graph);
+    auto functor = func_graph_to_functor_.find(func_graph);
+    if (functor == func_graph_to_functor_.end()) {
+      MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed: functor for subgraph does not exist, input[" << i << "] "
+                        << func_graph->ToString() << ".";
+    }
+    // Consider direct and indirect fvs.
+    for (auto fv : func_graph->free_variables_nodes()) {
+      BackPropagateFv(fv, env);
+    }
+    for (auto indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
+      MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
+                    << indirect_fv.first->ToString() << ".";
+      BackPropagateFv(indirect_fv.first, env);
+    }
+  }
+}
+
 void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
   auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
   // Call with delimited continuation dout.
   auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
   node_adjoint->RegisterDoutUser(bprop_app, 1);
+  // Special case for switch_layer.
+  if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
+    auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)});
+    BackPropagateSwitchLayer(cnode_morph, din);
+    return;
+  }
 for (size_t i = 0; i < cnode_morph->size(); i++) {
     auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
     auto input = cnode_morph->input(i);
@@ -402,6 +444,11 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
   return primal;
 }
 
+bool DFunctor::IsInScope(const AnfNodePtr &node) {
+  return std::any_of(scope_.begin(), scope_.end(),
+                     [&](const FuncGraphPtr &graph) { return node->func_graph() == graph; });
+}
+
 void DFunctor::MapFvObject() {
   // Map free variable.
   const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
@@ -414,8 +461,8 @@ void DFunctor::MapFvObject() {
     if (parent_adjoint != nullptr) {
       adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
     } else {
-      if (is_top_) {
-        // Top graph for ad, add adjoint for free variables.
+      if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
+        // Out of ad scope, add adjoint for free variables.
         adjoint = std::make_shared<Adjoint>(node, node, tape_);
         UpdateAdjoint(adjoint);
       } else {
@@ -62,9 +62,11 @@ class DFunctor {
   // Map one morphism.
   AdjointPtr MapMorphism(const AnfNodePtr &morph);
   bool IsFreeMorphism(const AnfNodePtr &node);
+  bool IsInScope(const AnfNodePtr &node);
   // Map morphism that's not attached to output.
   void MapFreeMorphism();
   void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
+  void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env);
   void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
   AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
   AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
@@ -101,6 +103,7 @@ class DFunctor {
   bool is_top_;
   static std::unordered_map<FuncGraphPtr, std::shared_ptr<DFunctor>> func_graph_to_functor_;
   static std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_definition_;
+  static FuncGraphSet scope_;
 };
 
 // D Functor's rules to map primitive object.
@@ -120,6 +123,7 @@ class KPrim {
 
  private:
   FuncGraphPtr GetBprop(const PrimitivePtr &prim);
+  FuncGraphPtr GetFprop(const PrimitivePtr &prim);
   FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
   // Given a bprop rule, do the K mapping.
   template <typename T>
@@ -62,6 +62,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
   return func_graph;
 }
 
+FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
+  static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
+  std::string func_name = "_fprop_" + prim->name();
+  py::function fn = parse::python_adapter::GetPyFn(ad_module, func_name);
+  auto func_graph = parse::ParsePythonCode(fn);
+  MS_EXCEPTION_IF_NULL(func_graph);
+  return BasicClone(func_graph);
+}
+
 MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
   MS_EXCEPTION_IF_NULL(prim);
 
@@ -92,6 +101,13 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
     return iter->second;
   }
 
+  if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
+    auto fprop = GetFprop(prim);
+    fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
+    bprop_registry_[prim::kPrimSwitchLayer] = fprop;
+    return fprop;
+  }
+
   if (prim->name() == "make_tuple") {
     return nullptr;
   }
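Note the special case in the hunk above: rather than looking up a registered bprop, KPrim::KPrimitive maps `switch_layer` through the new GetFprop, which loads `_fprop_switch_layer` from `mindspore.ops._grad.grad_implementations`, parses it into a graph, records the primitive as the graph's `primal` transform, and caches the result in `bprop_registry_`. The Python definition of that fprop appears in the grad_implementations hunk further down.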
@@ -50,6 +50,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimHasType, {InferImplHasType, false}},
     {prim::kPrimDot, {InferImplDot, true}},
     {prim::kPrimSwitch, {InferImplSwitch, true}},
+    {prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
     {prim::kPrimIs_, {InferImplIs_, true}},
     {prim::kPrimIsNot, {InferImplIsNot, true}},
     {prim::kPrimInDict, {InferImplInDict, true}},
@@ -174,6 +174,8 @@ AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &prim
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
                                 const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
+                                     const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,
@@ -242,3 +242,9 @@ def bprop_switch(cond, tb, fb, out, dout):
     """Backpropagator for primitive `switch`."""
     return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \
            F.switch(cond, C.zeros_like(fb), dout)
+
+def _fprop_switch_layer(index, layers):
+    """Forward propagator with its backpropagator for primitive `switch_layer`."""
+    def _bprop_switch_layer(dout):
+        return dout, C.zeros_like(index), ()
+    return F.switch_layer(index, layers), _bprop_switch_layer
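Taken together with the KPrim::KPrimitive hunk above, this fprop is what the K-mapping of `switch_layer` resolves to. Below is a minimal user-level sketch of the primitive this commit enables; it assumes a MindSpore build containing this commit, and the `PickOne` cell with its ReLU/Sigmoid candidate tuple is invented for illustration (the commit's own end-to-end test is in the last hunk):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C
from mindspore.ops import functional as F

context.set_context(mode=context.GRAPH_MODE)


class PickOne(nn.Cell):
    """Illustrative cell: applies the index-th candidate from a tuple of layers to x."""
    def __init__(self):
        super(PickOne, self).__init__()
        # Any tuple of cells/graphs with matching signatures can serve as the candidate set.
        self.layers = (nn.ReLU(), nn.Sigmoid())

    def construct(self, index, x):
        # switch_layer selects self.layers[index]; the selected graph is then called on x.
        return F.switch_layer(index, self.layers)(x)


net = PickOne()
x = Tensor(np.ones([2, 3], dtype=np.float32))
out = net(1, x)                 # forward pass through the second candidate
grads = C.grad_all(net)(0, x)   # gradients w.r.t. the inputs, taken through the selected candidate

Per the diff itself, `_bprop_switch_layer` returns `()` for the layer tuple and zeros for the index, while gradients for parameters captured inside the candidate graphs are routed as free variables through the BackPropagateSwitchLayer pass added earlier in this commit.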
@@ -135,6 +135,7 @@ env_getitem = Primitive('env_getitem')
 env_add = Primitive('env_add')
 J = Primitive('J')
 switch = Primitive('switch')
+switch_layer = Primitive('switch_layer')
 # for sum bprop
 reduced_shape = Primitive("reduced_shape")
 # shape_mul: input must be a shape; multiplies the elements of tuple(shape)
@@ -19,6 +19,9 @@ from mindspore import nn
 from mindspore import Tensor
 from mindspore import context
 from mindspore.ops import operations as P
+from mindspore.ops import composite as C
+from mindspore.ops import functional as F
+from mindspore.common.parameter import Parameter, ParameterTuple
 
 context.set_context(mode=context.GRAPH_MODE)
 
@@ -358,3 +361,33 @@ def test_if_compile_true():
 def test_if_compile_false():
     output = if_compile_test(8, 3)
     print("test_if_compile_false:", output)
+
+
+def test_switch_layer():
+    class Layer1(nn.Cell):
+        def __init__(self):
+            super(Layer1, self).__init__()
+            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+        def construct(self, x):
+            return x * self.z1
+
+    class Layer2(nn.Cell):
+        def __init__(self):
+            super(Layer2, self).__init__()
+            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+        def construct(self, x):
+            return x * self.z2
+
+    class SwitchLayerCell(nn.Cell):
+        def __init__(self):
+            super(SwitchLayerCell, self).__init__()
+            self.layers = (Layer1(), Layer2())
+            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+        def construct(self, index, x):
+            ret = F.switch_layer(index, self.layers)(x) * self.z3
+            return ret
+
+    net = SwitchLayerCell()
+    net(1, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))