!29930 Support List Slice

Merge pull request !29930 from lianliguang/support-list-slice
This commit is contained in:
i-robot 2022-02-17 07:24:40 +00:00 committed by Gitee
commit 7c32002e0b
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 280 additions and 22 deletions

View File

@ -164,7 +164,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
return meta_func_graph->isa<prim::Tail>() || meta_func_graph->isa<prim::MakeTupleGradient>() ||
meta_func_graph->isa<prim::MakeListGradient>() || meta_func_graph->isa<prim::TupleAdd>() ||
meta_func_graph->isa<prim::TupleSlice>() || meta_func_graph->isa<prim::UnpackCall>() ||
meta_func_graph->isa<prim::SequenceSlice>() || meta_func_graph->isa<prim::UnpackCall>() ||
meta_func_graph->isa<prim::ZipOperation>() || meta_func_graph->isa<prim::ListAppend>() ||
meta_func_graph->isa<prim::ListInsert>() || meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>();
}

View File

@ -1006,8 +1006,8 @@ int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, c
MS_LOG(EXCEPTION) << "The argument of SliceMember operator must be a Scalar or None, but got " << member->ToString();
}
void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int64_t *start_index,
int64_t *stop_index, int64_t *step_value) {
void GenerateTupleSliceParameter(const abstract::AbstractSequencePtr &tuple, const AbstractSlicePtr &slice,
int64_t *start_index, int64_t *stop_index, int64_t *step_value) {
MS_EXCEPTION_IF_NULL(tuple);
MS_EXCEPTION_IF_NULL(slice);
MS_EXCEPTION_IF_NULL(start_index);
@ -1053,19 +1053,16 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
}
}
FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
FuncGraphPtr SequenceSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// slice a tuple
// args: tuple, start index, end index, step
const std::string op_name("TupleSlice");
constexpr size_t arg_size = 2;
abstract::CheckArgsSize(op_name, args_spec_list, arg_size);
AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);
auto seq_pair = CheckArgs(args_spec_list);
auto sequence = seq_pair.first;
auto slice = seq_pair.second;
int64_t start_index;
int64_t stop_index;
int64_t step_value;
GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
GenerateTupleSliceParameter(sequence, slice, &start_index, &stop_index, &step_value);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
@ -1073,14 +1070,14 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
(void)ret->add_parameter();
std::vector<AnfNodePtr> elems;
elems.push_back(NewValueNode(prim::kPrimMakeTuple));
elems.push_back(NewValueNode(prim_));
if (step_value > 0) {
for (int64_t index = start_index; index < stop_index; index = index + step_value) {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_tuple, NewValueNode(index)}));
}
} else {
for (int64_t index = start_index; index > stop_index; index = index + step_value) {
elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_tuple, NewValueNode(index)}));
}
}
@ -1088,6 +1085,24 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
return ret;
}
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> TupleSlice::CheckArgs(
const AbstractBasePtrList &args_spec_list) {
constexpr size_t arg_size = 2;
abstract::CheckArgsSize("TupleSlice", args_spec_list, arg_size);
auto sequence = abstract::CheckArg<abstract::AbstractSequence>("TupleSlice", args_spec_list, 0);
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>("TupleSlice", args_spec_list, 1);
return std::make_pair(sequence, slice);
}
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> ListSlice::CheckArgs(
const AbstractBasePtrList &args_spec_list) {
constexpr size_t arg_size = 2;
abstract::CheckArgsSize("ListSlice", args_spec_list, arg_size);
auto sequence = abstract::CheckArg<abstract::AbstractSequence>("ListSlice", args_spec_list, 0);
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>("ListSlice", args_spec_list, 1);
return std::make_pair(sequence, slice);
}
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
// select indexed item
// args: tuple of items, index
@ -1119,6 +1134,11 @@ REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
.def(py::init<std::string &>());
}));
REGISTER_PYBIND_DEFINE(ListSlice_, ([](const py::module *m) {
(void)py::class_<ListSlice, MetaFuncGraph, std::shared_ptr<ListSlice>>(*m, "ListSlice_")
.def(py::init<std::string &>());
}));
namespace {
FuncGraphPtr GetShard(const AnfNodePtr &shard, const std::vector<AnfNodePtr> &origin_graph_params) {
FuncGraphPtr shard_child = std::make_shared<FuncGraph>();

View File

@ -188,16 +188,41 @@ class TupleAdd : public MetaFuncGraph {
};
using TupleAddPtr = std::shared_ptr<TupleAdd>;
class TupleSlice : public MetaFuncGraph {
class SequenceSlice : public MetaFuncGraph {
public:
explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
~TupleSlice() override = default;
MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
explicit SequenceSlice(const std::string &name, const PrimitivePtr &prim, const PrimitivePtr &get_item)
: MetaFuncGraph(name), prim_(prim), get_item_(get_item) {}
~SequenceSlice() override = default;
MS_DECLARE_PARENT(SequenceSlice, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
friend bool operator==(const SequenceSlice &lhs, const SequenceSlice &rhs) { return lhs.name_ == rhs.name_; }
virtual std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
const AbstractBasePtrList &args_spec_list) = 0;
private:
PrimitivePtr prim_;
PrimitivePtr get_item_;
};
class TupleSlice : public SequenceSlice {
public:
explicit TupleSlice(const std::string &name) : SequenceSlice(name, prim::kPrimMakeTuple, prim::kPrimTupleGetItem) {}
~TupleSlice() override = default;
MS_DECLARE_PARENT(TupleSlice, SequenceSlice)
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
const AbstractBasePtrList &args_spec_list) override;
};
using TupleSlicePtr = std::shared_ptr<TupleSlice>;
class ListSlice : public SequenceSlice {
public:
explicit ListSlice(const std::string &name) : SequenceSlice(name, prim::kPrimMakeList, prim::kPrimListGetItem) {}
~ListSlice() override = default;
MS_DECLARE_PARENT(ListSlice, SequenceSlice)
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
const AbstractBasePtrList &args_spec_list) override;
};
class TupleGetItemTensor : public MetaFuncGraph {
public:
explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}

View File

@ -21,7 +21,7 @@ from types import FunctionType
from mindspore import context
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, ListSlice_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
from ..primitive import Primitive
@ -29,7 +29,7 @@ from ..operations import _grad_ops
from .. import operations as P
from .. import signature as sig
__all__ = [TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
__all__ = [TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_, ListSlice_]
def add_flags(fn=None, **flags):

View File

@ -47,7 +47,31 @@ class _TupleSlice(base.TupleSlice_):
_tuple_slice = _TupleSlice('tuple_slice')
"""_tuple_slice is an metafuncgraph object which will slice a tuple."""
"""_tuple_slice is a metafuncgraph object which will slice a tuple."""
class _ListSlice(base.ListSlice_):
"""
Slices a List.
Inputs:
data (List): A List to be sliced.
s (slice): The index to slice list data.
Outputs:
List, consists of some elements of data.
"""
def __init__(self, name):
"""Initialize _TupleSlice."""
base.ListSlice_.__init__(self, name)
def __call__(self, *args):
pass
_list_slice = _ListSlice('list_slice')
"""_list_slice is a metafuncgraph object which will slice a list."""
class _TupleGetItemTensor(base.TupleGetItemTensor_):
@ -133,6 +157,19 @@ def _list_getitem_by_number(data, number_index):
"""
return F.list_getitem(data, number_index)
@getitem.register("List", "Slice")
def _list_getitem_by_slice(data, slice_index):
"""
Getting item of list by slice index.
Inputs:
data (list): data
slice_index (Slice): Index in slice.
Outputs:
List, element type is the same as the element type of data.
"""
return _list_slice(data, slice_index)
@getitem.register("Dictionary", "String")
def _dict_getitem_by_key(data, key):

View File

@ -35,6 +35,8 @@ using AbstractSlicePtr = abstract::AbstractSlicePtr;
using AbstractTuple = abstract::AbstractTuple;
using AbstractTuplePtr = abstract::AbstractTuplePtr;
using AbstractList = abstract::AbstractList;
using AbstractListPtr = abstract::AbstractListPtr;
using AbstractTensor = abstract::AbstractTensor;
using AbstractTensorPtr = abstract::AbstractTensorPtr;
@ -251,6 +253,157 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
ASSERT_EQ(real, expect);
}
/// Feature: Test list slice
/// Description: The second input is a scalar
/// Expectation: Throw type error
TEST_F(TestComposite, test_ListSlice_arg_one_number) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
FuncGraphPtr list_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 3);
AbstractBasePtrList eles;
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
size_t list_size = 6;
for (size_t i = 0; i < list_size; i++) {
eles.push_back(tensor);
}
auto list_tensor = std::make_shared<AbstractList>(eles);
auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
AbstractBasePtrList args_spec_list = {list_tensor, start_index};
try {
trace::ClearTraceStack();
engine_->Run(list_graph, args_spec_list);
FAIL() << "Excepted exception: Args type is wrong";
} catch (pybind11::type_error const &err) {
ASSERT_TRUE(true);
} catch (std::runtime_error const &err) {
if (std::strstr(err.what(), "TypeError") != nullptr) {
ASSERT_TRUE(true);
} else {
FAIL() << "Excepted exception: Args type is wrong, message: " << err.what();
}
} catch (...) {
FAIL() << "Excepted exception: Args type is wrong";
}
}
/// Feature: Test list slice
/// Description: Test List slice
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice) {
std::shared_ptr<py::scoped_interpreter> env = parse::python_adapter::set_python_scoped();
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
size_t list_size = 6;
for (size_t i = 0; i < list_size; i++) {
eles.push_back(tensor);
}
auto list_tensor = std::make_shared<AbstractList>(eles);
auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {list_tensor, slice};
AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract list failed.";
}
size_t real = ret->size();
size_t expect = 3;
ASSERT_EQ(real, expect);
}
/// Feature: Test list slice
/// Description: Test List slice the step is none
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
size_t list_size = 6;
for (size_t i = 0; i < list_size; i++) {
eles.push_back(tensor);
}
auto list_tensor = std::make_shared<AbstractList>(eles);
auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
auto step = std::make_shared<AbstractNone>();
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {list_tensor, slice};
AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract list failed.";
}
size_t real = ret->size();
size_t expect = 4;
ASSERT_EQ(real, expect);
}
/// Feature: Test list slice
/// Description: Test List slice the step is negative
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
size_t list_size = 6;
for (size_t i = 0; i < list_size; i++) {
eles.push_back(tensor);
}
auto list_tensor = std::make_shared<AbstractList>(eles);
auto start_index = std::make_shared<AbstractNone>();
auto stop_index = std::make_shared<AbstractNone>();
auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {list_tensor, slice};
AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract list failed.";
}
size_t real = ret->size();
size_t expect = 6;
ASSERT_EQ(real, expect);
}
/// Feature: Test list slice
/// Description: Test List slice the step is positive
/// Expectation: No Expectation
TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
AbstractBasePtrList eles;
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
size_t list_size = 6;
for (size_t i = 0; i < list_size; i++) {
eles.push_back(tensor);
}
auto list_tensor = std::make_shared<AbstractList>(eles);
auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
auto stop_index = std::make_shared<AbstractNone>();
auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
AbstractBasePtrList args_spec_list = {list_tensor, slice};
AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
if (ret == nullptr) {
FAIL() << "Cast ret to abstract list failed.";
}
size_t real = ret->size();
size_t expect = 5;
ASSERT_EQ(real, expect);
}
TEST_F(TestComposite, test_UnpackCall_3args) {
MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);

View File

@ -19,6 +19,7 @@ import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell
from mindspore import ops
class Net1(Cell):
@ -33,6 +34,17 @@ class Net1(Cell):
return x
class Net2(Cell):
def __init__(self, list1):
super().__init__()
self.list = list1
self.addn = ops.AddN()
def construct(self, x):
x = self.addn(self.list[0::2])
return x
def test_list1():
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
input_me = Tensor(input_np)
@ -45,3 +57,14 @@ def test_list2():
input_me = Tensor(input_np)
net = Net1([1, 2])
_cell_graph_executor.compile(net, input_me)
def test_list_slice():
"""
Feature: Support List Slice
Description: Test List Slice
Expectation: No exception.
"""
input_me = Tensor(8)
net = Net2([Tensor(1), Tensor(2), Tensor(3), Tensor(4), Tensor(5), Tensor(6), Tensor(7)])
_cell_graph_executor.compile(net, input_me)