From 065e25e1bb6d1d97efb3d2acbf33c2702cc786d5 Mon Sep 17 00:00:00 2001
From: panyifeng
Date: Mon, 11 May 2020 10:30:13 +0800
Subject: [PATCH] Support index to switch_layer

---
 .../ccsrc/operator/composite/composite.cc    | 27 ++++++++++++++
 .../ccsrc/operator/composite/composite.h     | 12 ++++++
 mindspore/ccsrc/operator/prim_statement.cc   | 21 +++++++----
 mindspore/ops/composite/base.py              |  4 +-
 .../composite/multitype_ops/getitem_impl.py  | 37 +++++++++++++++++++
 tests/ut/python/ops/test_control_ops.py      | 37 +++++++++++++++++--
 6 files changed, 125 insertions(+), 13 deletions(-)

diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc
index da4700b053e..bf62c4ddc62 100644
--- a/mindspore/ccsrc/operator/composite/composite.cc
+++ b/mindspore/ccsrc/operator/composite/composite.cc
@@ -1233,6 +1233,27 @@ FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNod
   return ret_graph;
 }
 
+FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
+  // select indexed item
+  // args: tuple of items, index
+  const std::string op_name = std::string("TupleGetItemTensor");
+  abstract::CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  AbstractBasePtrList branches = branches_abs->elements();
+
+  if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
+    FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
+    ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
+    AnfNodePtr functions = ret_graph->add_parameter();
+    auto index = ret_graph->add_parameter();
+
+    ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
+    return ret_graph;
+  }
+
+  MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support to index " << branches_abs->ToString() << ".";
+}
+
 REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
                          (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
                            .def(py::init<std::string &>());
@@ -1247,5 +1268,11 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
                          (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
                            .def(py::init<std::string &>());
                        }));
+
+REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
+                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
+                           *m, "TupleGetItemTensor_")
+                           .def(py::init<std::string &>());
+                       }));
 }  // namespace prim
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/operator/composite/composite.h b/mindspore/ccsrc/operator/composite/composite.h
index 6c4bede82bc..7061eb7441f 100644
--- a/mindspore/ccsrc/operator/composite/composite.h
+++ b/mindspore/ccsrc/operator/composite/composite.h
@@ -210,6 +210,18 @@ class TensorSlice : public MetaFuncGraph {
   FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
 };
 using TensorSlicePtr = std::shared_ptr<TensorSlice>;
+
+class TupleGetItemTensor : public MetaFuncGraph {
+ public:
+  explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
+  ~TupleGetItemTensor() override = default;
+  MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
+  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
+  friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
+    return lhs.name_ == rhs.name_;
+  }
+};
+using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
 }  // namespace prim
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/operator/prim_statement.cc b/mindspore/ccsrc/operator/prim_statement.cc
index e639b58a05b..c297e128e2e 100644
--- a/mindspore/ccsrc/operator/prim_statement.cc
+++ b/mindspore/ccsrc/operator/prim_statement.cc
@@ -129,22 +129,27 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
 AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
   // Inputs: index, branch
-  if (args_spec_list.size() != 2) {
-    MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 2 parameters, while the input size is "
-                      << args_spec_list.size() << ".";
-  }
-  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(primitive->name(), args_spec_list, 1);
+  const std::string op_name = primitive->name();
+  abstract::CheckArgsSize(op_name, args_spec_list, 2);
+  (void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
   AbstractBasePtrList branches = branches_abs->elements();
   const size_t maximum_layer_num = 1000;
   if (branches.size() < 0 || branches.size() > maximum_layer_num) {
-    MS_EXCEPTION(ValueError) << "SwitchLayer support at least 1 and at most " << maximum_layer_num << " but got "
+    MS_EXCEPTION(ValueError) << op_name << " supports at least 1 and at most " << maximum_layer_num << " but got "
                              << branches.size() << " branches.";
   }
 
-  MS_EXCEPTION_IF_NULL(branches[0]);
+  for (size_t i = 0; i < branches.size(); i++) {
+    MS_EXCEPTION_IF_NULL(branches[i]);
+    if (!branches[i]->isa<AbstractFunction>()) {
+      MS_LOG(EXCEPTION) << op_name << " requires that the second arg be a tuple of functions, but got "
+                        << branches[i]->ToString() << " as the " << i << "th element.";
+    }
+  }
+
   auto b = branches[0];
   for (size_t i = 1; i < branches.size(); i++) {
-    MS_EXCEPTION_IF_NULL(branches[i]);
     b = b->Join(branches[i]);
   }
   return b;
diff --git a/mindspore/ops/composite/base.py b/mindspore/ops/composite/base.py
index 4b559d1605b..b0fcbce7874 100644
--- a/mindspore/ops/composite/base.py
+++ b/mindspore/ops/composite/base.py
@@ -18,13 +18,13 @@
 """Basic composite operations."""
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
-    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_
+    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
 from ...common.api import ms_function
 from .. import functional as F
 from .. import operations as P
 
-__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_]
+__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
 
 
 def add_flags(fn, **flags):
diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py
index 540dd28b374..3df117837b0 100644
--- a/mindspore/ops/composite/multitype_ops/getitem_impl.py
+++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py
@@ -72,6 +72,28 @@ _tensor_slice = _TensorSlice('tensor_slice')
 """_tensor_slice is an metafuncgraph object which will slice a tensor."""
 
 
+class _TupleGetItemTensor(base.TupleGetItemTensor_):
+    """
+    Getting item of tuple by tensor index.
+
+    Inputs:
+        data (tuple): A tuple of items.
+        index (Tensor): The index in tensor.
+    Outputs:
+        Type, the same as the element type of data.
+    """
+
+    def __init__(self, name):
+        base.TupleGetItemTensor_.__init__(self, name)
+
+    def __call__(self, *args):
+        pass
+
+
+_tuple_get_item_tensor = _TupleGetItemTensor('tuple_get_item_tensor')
+"""_tuple_get_item_tensor is a metafuncgraph object which will select the indexed item."""
+
+
 @getitem.register("Tuple", "Number")
 def _tuple_getitem_by_number(data, number_index):
     """
@@ -102,6 +124,21 @@ def _tuple_getitem_by_slice(data, slice_index):
     return _tuple_slice(data, slice_index)
 
 
+@getitem.register("Tuple", "Tensor")
+def _tuple_getitem_by_tensor(data, tensor_index):
+    """
+    Getting item out of tuple by tensor index.
+
+    Inputs:
+        data (tuple): A tuple of items to index.
+        tensor_index (Tensor): Index to select item.
+
+    Outputs:
+        Type, the same as the element type of data.
+    """
+    return _tuple_get_item_tensor(data, tensor_index)
+
+
 @getitem.register("List", "Number")
 def _list_getitem_by_number(data, number_index):
     """
diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py
index a6c15444e4e..b182396e4fe 100644
--- a/tests/ut/python/ops/test_control_ops.py
+++ b/tests/ut/python/ops/test_control_ops.py
@@ -387,7 +387,38 @@ def test_switch_layer():
             ret = F.switch_layer(index, self.layers)(x) * self.z3
             return ret
 
+    index = Tensor(0)
     net = SwitchLayerCell()
-    net(1, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_all(net)(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+
+
+def test_index_to_switch_layer():
+    class Layer1(nn.Cell):
+        def __init__(self):
+            super(Layer1, self).__init__()
+            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+        def construct(self, x):
+            return x * self.z1
+
+    class Layer2(nn.Cell):
+        def __init__(self):
+            super(Layer2, self).__init__()
+            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+        def construct(self, x):
+            return x * self.z2
+
+    class SwitchLayerCell(nn.Cell):
+        def __init__(self):
+            super(SwitchLayerCell, self).__init__()
+            self.layers = (Layer1(), Layer2())
+            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+        def construct(self, index, x):
+            ret = self.layers[index](x) * self.z3
+            return ret
+
+    index = Tensor(0)
+    net = SwitchLayerCell()
+    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
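
Reviewer note (not part of the patch): the sketch below shows, end to end, what this change enables at the Python level, namely indexing a tuple of cells with a scalar Tensor inside construct, which the new ("Tuple", "Tensor") getitem overload dispatches to _tuple_get_item_tensor and lowers to switch_layer. It mirrors the new test_index_to_switch_layer test above; the Scale/IndexedLayers names, the [2, 3] shapes, and the explicit GRAPH_MODE setting are illustrative assumptions rather than anything this patch requires.

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor, Parameter

context.set_context(mode=context.GRAPH_MODE)


class Scale(nn.Cell):
    """A tiny branch cell that multiplies its input by a learnable scale."""
    def __init__(self, init_value, param_name):
        super(Scale, self).__init__()
        self.scale = Parameter(Tensor(np.full([2, 3], init_value, dtype=np.float32)), name=param_name)

    def construct(self, x):
        return x * self.scale


class IndexedLayers(nn.Cell):
    """Selects one branch cell by a scalar Tensor index inside construct."""
    def __init__(self):
        super(IndexedLayers, self).__init__()
        self.layers = (Scale(0.5, 'w0'), Scale(2.0, 'w1'))

    def construct(self, index, x):
        # With this patch, tuple-by-Tensor indexing is dispatched to
        # _tuple_get_item_tensor and lowered to switch_layer, i.e. it behaves
        # like F.switch_layer(index, self.layers)(x).
        return self.layers[index](x)


x = Tensor(np.ones([2, 3], dtype=np.float32))
net = IndexedLayers()
out0 = net(Tensor(0), x)  # runs the first branch (scale 0.5)
out1 = net(Tensor(1), x)  # runs the second branch (scale 2.0)

Note that the index must be a Tensor: a plain Python int still goes through the existing ("Tuple", "Number") overload and selects the branch at compile time, which is why the updated test passes Tensor(0) instead of the literal 0.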