Support index to switch_layer

panyifeng 2020-05-11 10:30:13 +08:00
parent 08d86c483c
commit 065e25e1bb
6 changed files with 125 additions and 13 deletions
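In short: this commit lets graph-mode code select one of several sub-networks with a Tensor index by writing plain tuple indexing, self.layers[index](x), which the frontend now lowers to the existing switch_layer primitive. A minimal usage sketch, modeled on the new test case below (TwoBranchNet, the branch Cells, and the shapes here are illustrative, not part of the commit):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor

    class TwoBranchNet(nn.Cell):
        def __init__(self):
            super(TwoBranchNet, self).__init__()
            self.layers = (nn.ReLU(), nn.Sigmoid())  # a tuple of callable Cells

        def construct(self, index, x):
            return self.layers[index](x)  # Tensor index -> switch_layer

    net = TwoBranchNet()
    out = net(Tensor(0), Tensor(np.ones([2, 3], dtype=np.float32)))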

View File

@@ -1233,6 +1233,27 @@ FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNod
   return ret_graph;
 }
 
+FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
+  // Select the indexed item.
+  // args: tuple of items, index
+  const std::string op_name = std::string("TupleGetItemTensor");
+  abstract::CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  AbstractBasePtrList branches = branches_abs->elements();
+  if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
+    FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
+    ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
+    AnfNodePtr functions = ret_graph->add_parameter();
+    auto index = ret_graph->add_parameter();
+    ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
+    return ret_graph;
+  }
+  MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support indexing " << branches_abs->ToString() << ".";
+}
+
 REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
                          (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
                            .def(py::init<std::string &>());
@@ -1247,5 +1268,11 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
                          (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
                            .def(py::init<std::string &>());
                        }));
+
+REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
+                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
+                           *m, "TupleGetItemTensor_")
+                           .def(py::init<std::string &>());
+                       }));
 }  // namespace prim
 }  // namespace mindspore
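The MetaFuncGraph above generates a two-parameter graph whose body is a single kPrimSwitchLayer call, with the index as the first primitive input, so indexing a tuple of functions behaves like the functional switch_layer already exercised in the tests. Roughly, in Python terms (an illustrative equivalence sketch, not code from this commit; F is mindspore.ops.functional as imported in the test file):

    # fns: a tuple of functions/Cells, i: a Tensor index
    def tuple_get_item_tensor(fns, i):
        return F.switch_layer(i, fns)  # returns the selected function

    # hence fns[i](x) compiles to F.switch_layer(i, fns)(x)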

View File

@@ -210,6 +210,18 @@ class TensorSlice : public MetaFuncGraph {
   FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
 };
 using TensorSlicePtr = std::shared_ptr<TensorSlice>;
+
+class TupleGetItemTensor : public MetaFuncGraph {
+ public:
+  explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
+  ~TupleGetItemTensor() override = default;
+  MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
+  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
+  friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
+    return lhs.name_ == rhs.name_;
+  }
+};
+using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
 }  // namespace prim
 }  // namespace mindspore

View File

@@ -129,22 +129,27 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
 AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
   // Inputs: index, branch
-  if (args_spec_list.size() != 2) {
-    MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 2 parameters, while the input size is "
-                      << args_spec_list.size() << ".";
-  }
-  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(primitive->name(), args_spec_list, 1);
+  const std::string op_name = primitive->name();
+  abstract::CheckArgsSize(op_name, args_spec_list, 2);
+  (void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
   AbstractBasePtrList branches = branches_abs->elements();
   const size_t maximum_layer_num = 1000;
   if (branches.size() < 1 || branches.size() > maximum_layer_num) {
-    MS_EXCEPTION(ValueError) << "SwitchLayer support at least 1 and at most " << maximum_layer_num << " but got "
+    MS_EXCEPTION(ValueError) << op_name << " supports at least 1 and at most " << maximum_layer_num << " but got "
                              << branches.size() << " branches.";
   }
-  MS_EXCEPTION_IF_NULL(branches[0]);
+  for (size_t i = 0; i < branches.size(); i++) {
+    MS_EXCEPTION_IF_NULL(branches[i]);
+    if (!branches[i]->isa<AbstractFunction>()) {
+      MS_LOG(EXCEPTION) << op_name << " requires that the 2nd arg be a tuple of functions, but got "
+                        << branches[i]->ToString() << " as the " << i << "th element.";
+    }
+  }
+
   auto b = branches[0];
   for (size_t i = 1; i < branches.size(); i++) {
-    MS_EXCEPTION_IF_NULL(branches[i]);
     b = b->Join(branches[i]);
   }
   return b;
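The rewritten checker validates both inputs up front: the index must infer as a tensor, the branch tuple must hold between 1 and 1000 elements, and every element must be an AbstractFunction, before the branch abstracts are joined into the result. In user-facing terms, roughly (a sketch paraphrasing the error paths above; exact messages and raise sites may differ):

    # switch_layer(0, (net_a, net_b))      -> rejected: the index must be a Tensor
    # switch_layer(Tensor(0), ())          -> rejected: needs 1..1000 branches
    # switch_layer(Tensor(0), (net_a, 1))  -> rejected: every element of the tuple
    #                                         must be a function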

View File

@@ -18,13 +18,13 @@
 """Basic composite operations."""
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
-    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_
+    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
 from ...common.api import ms_function
 from .. import functional as F
 from .. import operations as P
 
-__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_]
+__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
 
 
 def add_flags(fn, **flags):

View File

@@ -72,6 +72,28 @@ _tensor_slice = _TensorSlice('tensor_slice')
 """_tensor_slice is a metafuncgraph object which will slice a tensor."""
+
+
+class _TupleGetItemTensor(base.TupleGetItemTensor_):
+    """
+    Get an item of a tuple by a tensor index.
+
+    Inputs:
+        data (tuple): A tuple of items.
+        index (Tensor): The index expressed as a tensor.
+
+    Outputs:
+        Type, the same as the element type of data.
+    """
+
+    def __init__(self, name):
+        base.TupleGetItemTensor_.__init__(self, name)
+
+    def __call__(self, *args):
+        pass
+
+
+_tuple_get_item_tensor = _TupleGetItemTensor('tuple_get_item_tensor')
+"""_tuple_get_item_tensor is a metafuncgraph object which will select the indexed item."""
 
 
 @getitem.register("Tuple", "Number")
 def _tuple_getitem_by_number(data, number_index):
     """
@@ -102,6 +124,21 @@ def _tuple_getitem_by_slice(data, slice_index):
     return _tuple_slice(data, slice_index)
 
+
+@getitem.register("Tuple", "Tensor")
+def _tuple_getitem_by_tensor(data, tensor_index):
+    """
+    Get an item out of a tuple by a tensor index.
+
+    Inputs:
+        data (tuple): A tuple of items to index.
+        tensor_index (Tensor): Index to select an item.
+
+    Outputs:
+        Type, the same as the element type of data.
+    """
+    return _tuple_get_item_tensor(data, tensor_index)
+
 
 @getitem.register("List", "Number")
 def _list_getitem_by_number(data, number_index):
     """

View File

@@ -387,7 +387,38 @@ def test_switch_layer():
             ret = F.switch_layer(index, self.layers)(x) * self.z3
             return ret
 
+    index = Tensor(0)
     net = SwitchLayerCell()
-    net(1, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_all(net)(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+
+
+def test_index_to_switch_layer():
+    class Layer1(nn.Cell):
+        def __init__(self):
+            super(Layer1, self).__init__()
+            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+
+        def construct(self, x):
+            return x * self.z1
+
+    class Layer2(nn.Cell):
+        def __init__(self):
+            super(Layer2, self).__init__()
+            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+
+        def construct(self, x):
+            return x * self.z2
+
+    class SwitchLayerCell(nn.Cell):
+        def __init__(self):
+            super(SwitchLayerCell, self).__init__()
+            self.layers = (Layer1(), Layer2())
+            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+
+        def construct(self, index, x):
+            ret = self.layers[index](x) * self.z3
+            return ret
+
+    index = Tensor(0)
+    net = SwitchLayerCell()
+    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
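The new test mirrors test_switch_layer exactly except for the branch selection, so the two constructs below are expected to build the same graph (the equivalence follows from the C++ lowering earlier in this commit):

    ret = F.switch_layer(index, self.layers)(x) * self.z3  # explicit primitive
    ret = self.layers[index](x) * self.z3                  # tuple indexing, new in this commit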