forked from mindspore-Ecosystem/mindspore
Support index to switch_layer
parent 08d86c483c
commit 065e25e1bb
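In user code, the effect of this change is that a tuple of Cells can be indexed by a Tensor inside construct; the access is routed through the new TupleGetItemTensor meta func graph, which emits a switch_layer node. Below is a minimal usage sketch mirroring the new test_index_to_switch_layer test further down; the layer classes and shapes are illustrative only, and the imports assume the MindSpore Python API of this revision.

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter


class Layer1(nn.Cell):
    """First branch: scales the input by a learnable parameter."""
    def __init__(self):
        super(Layer1, self).__init__()
        self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')

    def construct(self, x):
        return x * self.z1


class Layer2(nn.Cell):
    """Second branch: same structure, separate parameter."""
    def __init__(self):
        super(Layer2, self).__init__()
        self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')

    def construct(self, x):
        return x * self.z2


class SwitchLayerCell(nn.Cell):
    """Selects one of the branches with a Tensor index."""
    def __init__(self):
        super(SwitchLayerCell, self).__init__()
        self.layers = (Layer1(), Layer2())

    def construct(self, index, x):
        # Indexing a tuple of Cells by a Tensor is dispatched to
        # _tuple_get_item_tensor, which lowers to switch_layer.
        return self.layers[index](x)


net = SwitchLayerCell()
out = net(Tensor(0), Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
```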
@@ -1233,6 +1233,27 @@ FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNod
  return ret_graph;
}

FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  // select indexed item
  // args: tuple of items, index
  const std::string op_name = std::string("TupleGetItemTensor");
  abstract::CheckArgsSize(op_name, args_spec_list, 2);
  AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
  AbstractBasePtrList branches = branches_abs->elements();

  if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
    FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
    ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
    AnfNodePtr functions = ret_graph->add_parameter();
    auto index = ret_graph->add_parameter();

    ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
    return ret_graph;
  }

  MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << ".";
}

REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
                         (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
                           .def(py::init<std::string &>());

@@ -1247,5 +1268,11 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
                         (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
                           .def(py::init<std::string &>());
                       }));

REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
                           *m, "TupleGetItemTensor_")
                           .def(py::init<std::string &>());
                       }));
}  // namespace prim
}  // namespace mindspore
@@ -210,6 +210,18 @@ class TensorSlice : public MetaFuncGraph {
  FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
};
using TensorSlicePtr = std::shared_ptr<TensorSlice>;

class TupleGetItemTensor : public MetaFuncGraph {
 public:
  explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
  ~TupleGetItemTensor() override = default;
  MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
    return lhs.name_ == rhs.name_;
  }
};
using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
}  // namespace prim
}  // namespace mindspore
@@ -129,22 +129,27 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const AbstractBasePtrList &args_spec_list) {
  // Inputs: index, branch
  if (args_spec_list.size() != 2) {
    MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 2 parameters, while the input size is "
                      << args_spec_list.size() << ".";
  }
  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(primitive->name(), args_spec_list, 1);
  const std::string op_name = primitive->name();
  abstract::CheckArgsSize(op_name, args_spec_list, 2);
  (void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
  AbstractBasePtrList branches = branches_abs->elements();
  const size_t maximum_layer_num = 1000;
  if (branches.size() < 1 || branches.size() > maximum_layer_num) {
    MS_EXCEPTION(ValueError) << "SwitchLayer support at least 1 and at most " << maximum_layer_num << " but got "
    MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got "
                             << branches.size() << " branches.";
  }

  MS_EXCEPTION_IF_NULL(branches[0]);
  for (size_t i = 0; i < branches.size(); i++) {
    MS_EXCEPTION_IF_NULL(branches[i]);
    if (!branches[i]->isa<AbstractFunction>()) {
      MS_LOG(EXCEPTION) << op_name << " requires that the 2th arg be tuple of functions, but got "
                        << branches[i]->ToString() << " as the " << i << "th element.";
    }
  }

  auto b = branches[0];
  for (size_t i = 1; i < branches.size(); i++) {
    MS_EXCEPTION_IF_NULL(branches[i]);
    b = b->Join(branches[i]);
  }
  return b;
@@ -18,13 +18,13 @@
"""Basic composite operations."""

from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_
    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
from ...common import dtype as mstype
from ...common.api import ms_function
from .. import functional as F
from .. import operations as P

__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_]
__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]


def add_flags(fn, **flags):
@@ -72,6 +72,28 @@ _tensor_slice = _TensorSlice('tensor_slice')
"""_tensor_slice is a metafuncgraph object which will slice a tensor."""


class _TupleGetItemTensor(base.TupleGetItemTensor_):
    """
    Getting item of tuple by tensor index.

    Inputs:
        data (tuple): A tuple of items.
        index (Tensor): The index in tensor.
    Outputs:
        Type, is same as the element type of data.
    """

    def __init__(self, name):
        base.TupleGetItemTensor_.__init__(self, name)

    def __call__(self, *args):
        pass


_tuple_get_item_tensor = _TupleGetItemTensor('tuple_get_item_tensor')
"""_tuple_get_item_tensor is a metafuncgraph object which will select the indexed item."""


@getitem.register("Tuple", "Number")
def _tuple_getitem_by_number(data, number_index):
    """

@@ -102,6 +124,21 @@ def _tuple_getitem_by_slice(data, slice_index):
    return _tuple_slice(data, slice_index)


@getitem.register("Tuple", "Tensor")
def _tuple_getitem_by_tensor(data, tensor_index):
    """
    Getting item out of tuple by tensor index.

    Inputs:
        data (tuple): A tuple of items to index.
        tensor_index (Tensor): Index to select item.

    Outputs:
        Type, is same as the element type of data.
    """
    return _tuple_get_item_tensor(data, tensor_index)


@getitem.register("List", "Number")
def _list_getitem_by_number(data, number_index):
    """
@@ -387,7 +387,38 @@ def test_switch_layer():
            ret = F.switch_layer(index, self.layers)(x) * self.z3
            return ret

    index = Tensor(0)
    net = SwitchLayerCell()
    net(1, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_all(net)(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))


def test_index_to_switch_layer():
    class Layer1(nn.Cell):
        def __init__(self):
            super(Layer1, self).__init__()
            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
        def construct(self, x):
            return x * self.z1

    class Layer2(nn.Cell):
        def __init__(self):
            super(Layer2, self).__init__()
            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
        def construct(self, x):
            return x * self.z2

    class SwitchLayerCell(nn.Cell):
        def __init__(self):
            super(SwitchLayerCell, self).__init__()
            self.layers = (Layer1(), Layer2())
            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
        def construct(self, index, x):
            ret = self.layers[index](x) * self.z3
            return ret

    index = Tensor(0)
    net = SwitchLayerCell()
    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))