Add switch_layer primitive

This commit is contained in:
panyifeng 2020-04-24 11:13:40 +08:00
parent 5a03bd8077
commit b7596e1f33
11 changed files with 138 additions and 2 deletions

View File

@ -59,6 +59,7 @@ const PrimitivePtr kPrimHasType = std::make_shared<Primitive>("hastype");
// Statements
const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
// "switch_layer": takes an index and a tuple of candidate branches (see InferImplSwitchLayer).
const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");

View File

@ -65,6 +65,7 @@ extern const PrimitivePtr kPrimHasType;
// Statements
extern const PrimitivePtr kPrimSwitch;
// Primitive registered under the name "switch_layer".
extern const PrimitivePtr kPrimSwitchLayer;
extern const PrimitivePtr kPrimReturn;
extern const PrimitivePtr kPrimAssign;
extern const PrimitivePtr kPrimAssignAdd;

View File

@ -126,6 +126,30 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
MS_LOG(EXCEPTION) << "Invalid condition value for switch " << cond->ToString();
}
// Infers the abstract value of a `switch_layer` call.
//
// Inputs: index, branches (argument 1 is fetched through CheckArg as an AbstractTuple).
// Returns the join of the abstract values of all branches, folded left to right starting
// from the first branch.
// Raises an exception when the argument count is not 2, when the tuple holds no branch or
// more than `maximum_layer_num` branches, or when any branch is null.
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.size() != 2) {
    MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 2 parameters, while the input size is "
                      << args_spec_list.size() << ".";
  }

  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(primitive->name(), args_spec_list, 1);
  // Bind by const reference: the list is only read, so there is no need to copy every element.
  const AbstractBasePtrList &branches = branches_abs->elements();
  const size_t maximum_layer_num = 1000;
  // size() is unsigned, so a `< 0` test can never fire; reject the empty tuple explicitly,
  // otherwise branches[0] below would be read out of bounds.
  if (branches.empty() || branches.size() > maximum_layer_num) {
    MS_EXCEPTION(ValueError) << "SwitchLayer support at least 1 and at most " << maximum_layer_num << " but got "
                             << branches.size() << " branches.";
  }

  MS_EXCEPTION_IF_NULL(branches[0]);
  auto b = branches[0];
  for (size_t i = 1; i < branches.size(); i++) {
    MS_EXCEPTION_IF_NULL(branches[i]);
    b = b->Join(branches[i]);
  }
  return b;
}
std::vector<ValuePtr> GetSupportedTargetValue() {
std::vector<ValuePtr> list = {kNone, MakeValue(false), MakeValue(true)};
return list;

View File

@ -38,6 +38,7 @@ namespace mindspore {
namespace ad {
std::unordered_map<FuncGraphPtr, DFunctorPtr> DFunctor::func_graph_to_functor_;
std::unordered_map<AnfNodePtr, AdjointPtr> DFunctor::anfnode_to_adjoin_definition_;
FuncGraphSet DFunctor::scope_;
DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBasePtr &resources)
: primal_graph_(primal_graph), resources_(resources), need_cut_(false), is_top_(false) {
@ -55,11 +56,15 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
// Registers `functor` as the DFunctor of this primal graph and records whether this graph is
// the top graph. Only for the top graph, the static scope_ is replaced by the primal graph's
// scope().
void DFunctor::Init(const DFunctorPtr &functor, bool is_top) {
  func_graph_to_functor_[primal_graph_] = functor;
  is_top_ = is_top;
  if (!is_top) {
    return;
  }
  scope_ = primal_graph_->scope();
}
// Empties the static bookkeeping of DFunctor: the graph-to-functor map, the
// node-to-adjoint-definition map and the scope set.
void DFunctor::Clear() {
func_graph_to_functor_.clear();
anfnode_to_adjoin_definition_.clear();
scope_.clear();
}
void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
@ -95,11 +100,48 @@ void DFunctor::BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din) {
fv_adjoint->second->AccumulateDout(dfv);
}
// Back-propagates `env` to the free variables of every candidate graph of a switch_layer call.
//
// cnode_morph: the switch_layer CNode; its input(2) must be a make_tuple CNode whose inputs
//              1..n are FuncGraph value nodes.
// env:         the node handed to BackPropagateFv for each free variable.
// Raises an exception when input(2) is not a make_tuple CNode, when one of its elements is not
// a FuncGraph value node, or when no functor is registered for a candidate graph.
void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env) {
  // Take switch_layer as a set of candidate functions.
  auto input = cnode_morph->input(2);
  if (!IsPrimitiveCNode(input, prim::kPrimMakeTuple)) {
    MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << input->ToString() << ".";
  }
  auto tuple_graphs = input->cast<CNodePtr>();
  // Input 0 of the make_tuple CNode is the primitive itself, so the graphs start at index 1.
  for (size_t i = 1; i < tuple_graphs->size(); ++i) {
    auto graph = tuple_graphs->input(i);
    if (!IsValueNode<FuncGraph>(graph)) {
      MS_LOG(EXCEPTION) << "The 2th input of switch_layer expect a tuple of graphs, but got " << graph->ToString()
                        << " as the " << i << "th element.";
    }
    auto func_graph = GetValueNode<FuncGraphPtr>(graph);
    auto functor = func_graph_to_functor_.find(func_graph);
    if (functor == func_graph_to_functor_.end()) {
      MS_LOG(EXCEPTION) << "BackPropagateSwitchLayer failed functor for subgraph does not exist input[" << i << "] "
                        << func_graph->ToString() << ".";
    }
    // Consider direct and indirect fvs.
    // Iterate by const reference so that no smart pointer (or map entry) is copied per element.
    for (const auto &fv : func_graph->free_variables_nodes()) {
      BackPropagateFv(fv, env);
    }
    for (const auto &indirect_fv : functor->second->anfnode_to_adjoin_indirect_fv_) {
      MS_LOG(DEBUG) << "BackPropagateSwitchLayer backprop indirect fv " << func_graph->ToString() << " "
                    << indirect_fv.first->ToString() << ".";
      BackPropagateFv(indirect_fv.first, env);
    }
  }
}
void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(1)});
// Call with delimited continuation dout.
auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
node_adjoint->RegisterDoutUser(bprop_app, 1);
// Special case for switch_layer
if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(0)});
BackPropagateSwitchLayer(cnode_morph, din);
return;
}
for (size_t i = 0; i < cnode_morph->size(); i++) {
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), bprop_app, NewValueNode(SizeToInt(i))});
auto input = cnode_morph->input(i);
@ -402,6 +444,11 @@ AnfNodePtr DFunctor::MapToK(const AnfNodePtr &primal) {
return primal;
}
bool DFunctor::IsInScope(const AnfNodePtr &node) {
return std::any_of(scope_.begin(), scope_.end(),
[&](const FuncGraphPtr &graph) { return node->func_graph() == graph; });
}
void DFunctor::MapFvObject() {
// Map free variable.
const auto &free_variables_nodes = primal_graph_->free_variables_nodes();
@ -414,8 +461,8 @@ void DFunctor::MapFvObject() {
if (parent_adjoint != nullptr) {
adjoint = std::make_shared<Adjoint>(node, parent_adjoint->k(), tape_);
} else {
if (is_top_) {
// Top graph for ad, add adjoint for free variables.
if (is_top_ || node->isa<Parameter>() || !IsInScope(node)) {
// Out of ad scope, add adjoint for free variables.
adjoint = std::make_shared<Adjoint>(node, node, tape_);
UpdateAdjoint(adjoint);
} else {

View File

@ -62,9 +62,11 @@ class DFunctor {
// Map one morphism.
AdjointPtr MapMorphism(const AnfNodePtr &morph);
bool IsFreeMorphism(const AnfNodePtr &node);
bool IsInScope(const AnfNodePtr &node);
// Map morphism that's not attached to output.
void MapFreeMorphism();
void BackPropagateFv(const AnfNodePtr &fv, const AnfNodePtr &din);
void BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNodePtr &env);
void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
@ -101,6 +103,7 @@ class DFunctor {
bool is_top_;
static std::unordered_map<FuncGraphPtr, std::shared_ptr<DFunctor>> func_graph_to_functor_;
static std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_definition_;
static FuncGraphSet scope_;
};
// D Functor's rules to map primitive object.
@ -120,6 +123,7 @@ class KPrim {
private:
FuncGraphPtr GetBprop(const PrimitivePtr &prim);
FuncGraphPtr GetFprop(const PrimitivePtr &prim);
FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
// Given a bprop rule, do the K mapping.
template <typename T>

View File

@ -62,6 +62,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
return func_graph;
}
// Builds the fprop func graph of `prim` from its Python definition.
//
// The Python function named "_fprop_" + prim->name() is fetched from the
// grad_implementations module and parsed; an exception is raised when parsing yields null.
// The value returned is BasicClone() of the parsed graph.
FuncGraphPtr KPrim::GetFprop(const PrimitivePtr &prim) {
  static const std::string ad_module = "mindspore.ops._grad.grad_implementations";
  const std::string fprop_name = "_fprop_" + prim->name();
  py::function py_fprop = parse::python_adapter::GetPyFn(ad_module, fprop_name);
  auto parsed_graph = parse::ParsePythonCode(py_fprop);
  MS_EXCEPTION_IF_NULL(parsed_graph);
  return BasicClone(parsed_graph);
}
MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
MS_EXCEPTION_IF_NULL(prim);
@ -92,6 +101,13 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
return iter->second;
}
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
auto fprop = GetFprop(prim);
fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
bprop_registry_[prim::kPrimSwitchLayer] = fprop;
return fprop;
}
if (prim->name() == "make_tuple") {
return nullptr;
}

View File

@ -50,6 +50,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimHasType, {InferImplHasType, false}},
{prim::kPrimDot, {InferImplDot, true}},
{prim::kPrimSwitch, {InferImplSwitch, true}},
{prim::kPrimSwitchLayer, {InferImplSwitchLayer, true}},
{prim::kPrimIs_, {InferImplIs_, true}},
{prim::kPrimIsNot, {InferImplIsNot, true}},
{prim::kPrimInDict, {InferImplInDict, true}},

View File

@ -174,6 +174,8 @@ AbstractBasePtr InferImplDot(const AnalysisEnginePtr &, const PrimitivePtr &prim
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
// Inference implementation for the switch_layer primitive; args_spec_list holds (index, branches).
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIs_(const AnalysisEnginePtr &, const PrimitivePtr &,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsNot(const AnalysisEnginePtr &, const PrimitivePtr &,

View File

@ -242,3 +242,9 @@ def bprop_switch(cond, tb, fb, out, dout):
"""Backpropagator for primitive `switch`."""
return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \
F.switch(cond, C.zeros_like(fb), dout)
def _fprop_switch_layer(index, layers):
"""Forward propagator (fprop) for primitive `switch_layer`.

Returns a pair: the result of `F.switch_layer(index, layers)` and the
closure `_bprop_switch_layer` that maps `dout` to the backward outputs.
"""
# The closure returns a 3-tuple: `dout` itself, zeros shaped like `index`, and an empty tuple.
# NOTE(review): presumably these line up with (selected function, index, layers) - confirm
# against the switch_layer special case in DFunctor::BackPropagate, which reads item 0.
def _bprop_switch_layer(dout):
return dout, C.zeros_like(index), ()
return F.switch_layer(index, layers), _bprop_switch_layer

View File

@ -135,6 +135,7 @@ env_getitem = Primitive('env_getitem')
env_add = Primitive('env_add')
J = Primitive('J')
switch = Primitive('switch')
switch_layer = Primitive('switch_layer')
# for sum bprop
reduced_shape = Primitive("reduced_shape")
# shape_mul:input mush be shape multiply elemts in tuple(shape)

View File

@ -19,6 +19,9 @@ from mindspore import nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter, ParameterTuple
context.set_context(mode=context.GRAPH_MODE)
@ -358,3 +361,33 @@ def test_if_compile_true():
def test_if_compile_false():
    """Run if_compile_test with the arguments (8, 3) and print what it returns."""
    print("test_if_compile_false:", if_compile_test(8, 3))
def test_switch_layer():
    """Smoke test: run a cell built on F.switch_layer forward, then through two gradient wrappers."""

    def make_tensor():
        # Every parameter and input in this test is a fresh 128x96 float32 tensor filled with 0.6.
        return Tensor(np.full([128, 96], 0.6, dtype=np.float32))

    class Layer1(nn.Cell):
        def __init__(self):
            super(Layer1, self).__init__()
            self.z1 = Parameter(make_tensor(), name='z1')

        def construct(self, x):
            return x * self.z1

    class Layer2(nn.Cell):
        def __init__(self):
            super(Layer2, self).__init__()
            self.z2 = Parameter(make_tensor(), name='z2')

        def construct(self, x):
            return x * self.z2

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (Layer1(), Layer2())
            self.z3 = Parameter(make_tensor(), name='z3')

        def construct(self, index, x):
            # Pick a layer by index, apply it to x, then scale by this cell's own parameter.
            return F.switch_layer(index, self.layers)(x) * self.z3

    net = SwitchLayerCell()

    # Forward pass with index 1.
    net(1, make_tensor())

    # Gradient wrappers, both evaluated with index 0.
    grad_by_list_fn = C.grad_by_list(net, ParameterTuple(net.trainable_params()))
    grad_by_list_fn(0, make_tensor())
    grad_all_fn = C.grad_all(net)
    grad_all_fn(0, make_tensor())