forked from mindspore-Ecosystem/mindspore
!29930 Support List Slice
Merge pull request !29930 from lianliguang/support-list-slice
This commit is contained in:
commit
7c32002e0b
|
@ -164,7 +164,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap
|
||||||
inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
|
inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
|
||||||
return meta_func_graph->isa<prim::Tail>() || meta_func_graph->isa<prim::MakeTupleGradient>() ||
|
return meta_func_graph->isa<prim::Tail>() || meta_func_graph->isa<prim::MakeTupleGradient>() ||
|
||||||
meta_func_graph->isa<prim::MakeListGradient>() || meta_func_graph->isa<prim::TupleAdd>() ||
|
meta_func_graph->isa<prim::MakeListGradient>() || meta_func_graph->isa<prim::TupleAdd>() ||
|
||||||
meta_func_graph->isa<prim::TupleSlice>() || meta_func_graph->isa<prim::UnpackCall>() ||
|
meta_func_graph->isa<prim::SequenceSlice>() || meta_func_graph->isa<prim::UnpackCall>() ||
|
||||||
meta_func_graph->isa<prim::ZipOperation>() || meta_func_graph->isa<prim::ListAppend>() ||
|
meta_func_graph->isa<prim::ZipOperation>() || meta_func_graph->isa<prim::ListAppend>() ||
|
||||||
meta_func_graph->isa<prim::ListInsert>() || meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>();
|
meta_func_graph->isa<prim::ListInsert>() || meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1006,8 +1006,8 @@ int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, c
|
||||||
MS_LOG(EXCEPTION) << "The argument of SliceMember operator must be a Scalar or None, but got " << member->ToString();
|
MS_LOG(EXCEPTION) << "The argument of SliceMember operator must be a Scalar or None, but got " << member->ToString();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int64_t *start_index,
|
void GenerateTupleSliceParameter(const abstract::AbstractSequencePtr &tuple, const AbstractSlicePtr &slice,
|
||||||
int64_t *stop_index, int64_t *step_value) {
|
int64_t *start_index, int64_t *stop_index, int64_t *step_value) {
|
||||||
MS_EXCEPTION_IF_NULL(tuple);
|
MS_EXCEPTION_IF_NULL(tuple);
|
||||||
MS_EXCEPTION_IF_NULL(slice);
|
MS_EXCEPTION_IF_NULL(slice);
|
||||||
MS_EXCEPTION_IF_NULL(start_index);
|
MS_EXCEPTION_IF_NULL(start_index);
|
||||||
|
@ -1053,19 +1053,16 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSl
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
FuncGraphPtr SequenceSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
// slice a tuple
|
// slice a tuple
|
||||||
// args: tuple, start index, end index, step
|
// args: tuple, start index, end index, step
|
||||||
const std::string op_name("TupleSlice");
|
auto seq_pair = CheckArgs(args_spec_list);
|
||||||
constexpr size_t arg_size = 2;
|
auto sequence = seq_pair.first;
|
||||||
abstract::CheckArgsSize(op_name, args_spec_list, arg_size);
|
auto slice = seq_pair.second;
|
||||||
AbstractTuplePtr tuple = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
|
||||||
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>(op_name, args_spec_list, 1);
|
|
||||||
|
|
||||||
int64_t start_index;
|
int64_t start_index;
|
||||||
int64_t stop_index;
|
int64_t stop_index;
|
||||||
int64_t step_value;
|
int64_t step_value;
|
||||||
GenerateTupleSliceParameter(tuple, slice, &start_index, &stop_index, &step_value);
|
GenerateTupleSliceParameter(sequence, slice, &start_index, &stop_index, &step_value);
|
||||||
|
|
||||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||||
|
@ -1073,14 +1070,14 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
|
||||||
(void)ret->add_parameter();
|
(void)ret->add_parameter();
|
||||||
|
|
||||||
std::vector<AnfNodePtr> elems;
|
std::vector<AnfNodePtr> elems;
|
||||||
elems.push_back(NewValueNode(prim::kPrimMakeTuple));
|
elems.push_back(NewValueNode(prim_));
|
||||||
if (step_value > 0) {
|
if (step_value > 0) {
|
||||||
for (int64_t index = start_index; index < stop_index; index = index + step_value) {
|
for (int64_t index = start_index; index < stop_index; index = index + step_value) {
|
||||||
elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
|
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_tuple, NewValueNode(index)}));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (int64_t index = start_index; index > stop_index; index = index + step_value) {
|
for (int64_t index = start_index; index > stop_index; index = index + step_value) {
|
||||||
elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)}));
|
elems.push_back(ret->NewCNodeInOrder({NewValueNode(get_item_), p_tuple, NewValueNode(index)}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1088,6 +1085,24 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> TupleSlice::CheckArgs(
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
constexpr size_t arg_size = 2;
|
||||||
|
abstract::CheckArgsSize("TupleSlice", args_spec_list, arg_size);
|
||||||
|
auto sequence = abstract::CheckArg<abstract::AbstractSequence>("TupleSlice", args_spec_list, 0);
|
||||||
|
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>("TupleSlice", args_spec_list, 1);
|
||||||
|
return std::make_pair(sequence, slice);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> ListSlice::CheckArgs(
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
constexpr size_t arg_size = 2;
|
||||||
|
abstract::CheckArgsSize("ListSlice", args_spec_list, arg_size);
|
||||||
|
auto sequence = abstract::CheckArg<abstract::AbstractSequence>("ListSlice", args_spec_list, 0);
|
||||||
|
AbstractSlicePtr slice = abstract::CheckArg<AbstractSlice>("ListSlice", args_spec_list, 1);
|
||||||
|
return std::make_pair(sequence, slice);
|
||||||
|
}
|
||||||
|
|
||||||
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||||
// select indexed item
|
// select indexed item
|
||||||
// args: tuple of items, index
|
// args: tuple of items, index
|
||||||
|
@ -1119,6 +1134,11 @@ REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
|
||||||
.def(py::init<std::string &>());
|
.def(py::init<std::string &>());
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
REGISTER_PYBIND_DEFINE(ListSlice_, ([](const py::module *m) {
|
||||||
|
(void)py::class_<ListSlice, MetaFuncGraph, std::shared_ptr<ListSlice>>(*m, "ListSlice_")
|
||||||
|
.def(py::init<std::string &>());
|
||||||
|
}));
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
FuncGraphPtr GetShard(const AnfNodePtr &shard, const std::vector<AnfNodePtr> &origin_graph_params) {
|
FuncGraphPtr GetShard(const AnfNodePtr &shard, const std::vector<AnfNodePtr> &origin_graph_params) {
|
||||||
FuncGraphPtr shard_child = std::make_shared<FuncGraph>();
|
FuncGraphPtr shard_child = std::make_shared<FuncGraph>();
|
||||||
|
|
|
@ -188,16 +188,41 @@ class TupleAdd : public MetaFuncGraph {
|
||||||
};
|
};
|
||||||
using TupleAddPtr = std::shared_ptr<TupleAdd>;
|
using TupleAddPtr = std::shared_ptr<TupleAdd>;
|
||||||
|
|
||||||
class TupleSlice : public MetaFuncGraph {
|
class SequenceSlice : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {}
|
explicit SequenceSlice(const std::string &name, const PrimitivePtr &prim, const PrimitivePtr &get_item)
|
||||||
~TupleSlice() override = default;
|
: MetaFuncGraph(name), prim_(prim), get_item_(get_item) {}
|
||||||
MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph)
|
~SequenceSlice() override = default;
|
||||||
|
MS_DECLARE_PARENT(SequenceSlice, MetaFuncGraph)
|
||||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||||
friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; }
|
friend bool operator==(const SequenceSlice &lhs, const SequenceSlice &rhs) { return lhs.name_ == rhs.name_; }
|
||||||
|
virtual std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
|
||||||
|
const AbstractBasePtrList &args_spec_list) = 0;
|
||||||
|
|
||||||
|
private:
|
||||||
|
PrimitivePtr prim_;
|
||||||
|
PrimitivePtr get_item_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class TupleSlice : public SequenceSlice {
|
||||||
|
public:
|
||||||
|
explicit TupleSlice(const std::string &name) : SequenceSlice(name, prim::kPrimMakeTuple, prim::kPrimTupleGetItem) {}
|
||||||
|
~TupleSlice() override = default;
|
||||||
|
MS_DECLARE_PARENT(TupleSlice, SequenceSlice)
|
||||||
|
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
|
||||||
|
const AbstractBasePtrList &args_spec_list) override;
|
||||||
};
|
};
|
||||||
using TupleSlicePtr = std::shared_ptr<TupleSlice>;
|
using TupleSlicePtr = std::shared_ptr<TupleSlice>;
|
||||||
|
|
||||||
|
class ListSlice : public SequenceSlice {
|
||||||
|
public:
|
||||||
|
explicit ListSlice(const std::string &name) : SequenceSlice(name, prim::kPrimMakeList, prim::kPrimListGetItem) {}
|
||||||
|
~ListSlice() override = default;
|
||||||
|
MS_DECLARE_PARENT(ListSlice, SequenceSlice)
|
||||||
|
std::pair<abstract::AbstractSequencePtr, abstract::AbstractSlicePtr> CheckArgs(
|
||||||
|
const AbstractBasePtrList &args_spec_list) override;
|
||||||
|
};
|
||||||
|
|
||||||
class TupleGetItemTensor : public MetaFuncGraph {
|
class TupleGetItemTensor : public MetaFuncGraph {
|
||||||
public:
|
public:
|
||||||
explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
|
explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
|
||||||
|
|
|
@ -21,7 +21,7 @@ from types import FunctionType
|
||||||
|
|
||||||
from mindspore import context
|
from mindspore import context
|
||||||
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
|
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
|
||||||
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_
|
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, ListSlice_
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ...common.api import ms_function, _pynative_executor, _wrap_func
|
from ...common.api import ms_function, _pynative_executor, _wrap_func
|
||||||
from ..primitive import Primitive
|
from ..primitive import Primitive
|
||||||
|
@ -29,7 +29,7 @@ from ..operations import _grad_ops
|
||||||
from .. import operations as P
|
from .. import operations as P
|
||||||
from .. import signature as sig
|
from .. import signature as sig
|
||||||
|
|
||||||
__all__ = [TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
|
__all__ = [TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_, ListSlice_]
|
||||||
|
|
||||||
|
|
||||||
def add_flags(fn=None, **flags):
|
def add_flags(fn=None, **flags):
|
||||||
|
|
|
@ -47,7 +47,31 @@ class _TupleSlice(base.TupleSlice_):
|
||||||
|
|
||||||
|
|
||||||
_tuple_slice = _TupleSlice('tuple_slice')
|
_tuple_slice = _TupleSlice('tuple_slice')
|
||||||
"""_tuple_slice is an metafuncgraph object which will slice a tuple."""
|
"""_tuple_slice is a metafuncgraph object which will slice a tuple."""
|
||||||
|
|
||||||
|
|
||||||
|
class _ListSlice(base.ListSlice_):
|
||||||
|
"""
|
||||||
|
Slices a List.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (List): A List to be sliced.
|
||||||
|
s (slice): The index to slice list data.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
List, consists of some elements of data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
"""Initialize _TupleSlice."""
|
||||||
|
base.ListSlice_.__init__(self, name)
|
||||||
|
|
||||||
|
def __call__(self, *args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_list_slice = _ListSlice('list_slice')
|
||||||
|
"""_list_slice is a metafuncgraph object which will slice a list."""
|
||||||
|
|
||||||
|
|
||||||
class _TupleGetItemTensor(base.TupleGetItemTensor_):
|
class _TupleGetItemTensor(base.TupleGetItemTensor_):
|
||||||
|
@ -133,6 +157,19 @@ def _list_getitem_by_number(data, number_index):
|
||||||
"""
|
"""
|
||||||
return F.list_getitem(data, number_index)
|
return F.list_getitem(data, number_index)
|
||||||
|
|
||||||
|
@getitem.register("List", "Slice")
|
||||||
|
def _list_getitem_by_slice(data, slice_index):
|
||||||
|
"""
|
||||||
|
Getting item of list by slice index.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
data (list): data
|
||||||
|
slice_index (Slice): Index in slice.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
List, element type is the same as the element type of data.
|
||||||
|
"""
|
||||||
|
return _list_slice(data, slice_index)
|
||||||
|
|
||||||
@getitem.register("Dictionary", "String")
|
@getitem.register("Dictionary", "String")
|
||||||
def _dict_getitem_by_key(data, key):
|
def _dict_getitem_by_key(data, key):
|
||||||
|
|
|
@ -35,6 +35,8 @@ using AbstractSlicePtr = abstract::AbstractSlicePtr;
|
||||||
|
|
||||||
using AbstractTuple = abstract::AbstractTuple;
|
using AbstractTuple = abstract::AbstractTuple;
|
||||||
using AbstractTuplePtr = abstract::AbstractTuplePtr;
|
using AbstractTuplePtr = abstract::AbstractTuplePtr;
|
||||||
|
using AbstractList = abstract::AbstractList;
|
||||||
|
using AbstractListPtr = abstract::AbstractListPtr;
|
||||||
|
|
||||||
using AbstractTensor = abstract::AbstractTensor;
|
using AbstractTensor = abstract::AbstractTensor;
|
||||||
using AbstractTensorPtr = abstract::AbstractTensorPtr;
|
using AbstractTensorPtr = abstract::AbstractTensorPtr;
|
||||||
|
@ -251,6 +253,157 @@ TEST_F(TestComposite, test_TupleSlice_arg_slice_step_positive) {
|
||||||
ASSERT_EQ(real, expect);
|
ASSERT_EQ(real, expect);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Feature: Test list slice
|
||||||
|
/// Description: The second input is a scalar
|
||||||
|
/// Expectation: Throw type error
|
||||||
|
TEST_F(TestComposite, test_ListSlice_arg_one_number) {
|
||||||
|
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
|
||||||
|
FuncGraphPtr list_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 3);
|
||||||
|
|
||||||
|
AbstractBasePtrList eles;
|
||||||
|
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
|
||||||
|
size_t list_size = 6;
|
||||||
|
for (size_t i = 0; i < list_size; i++) {
|
||||||
|
eles.push_back(tensor);
|
||||||
|
}
|
||||||
|
auto list_tensor = std::make_shared<AbstractList>(eles);
|
||||||
|
auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
|
||||||
|
AbstractBasePtrList args_spec_list = {list_tensor, start_index};
|
||||||
|
|
||||||
|
try {
|
||||||
|
trace::ClearTraceStack();
|
||||||
|
engine_->Run(list_graph, args_spec_list);
|
||||||
|
FAIL() << "Excepted exception: Args type is wrong";
|
||||||
|
} catch (pybind11::type_error const &err) {
|
||||||
|
ASSERT_TRUE(true);
|
||||||
|
} catch (std::runtime_error const &err) {
|
||||||
|
if (std::strstr(err.what(), "TypeError") != nullptr) {
|
||||||
|
ASSERT_TRUE(true);
|
||||||
|
} else {
|
||||||
|
FAIL() << "Excepted exception: Args type is wrong, message: " << err.what();
|
||||||
|
}
|
||||||
|
} catch (...) {
|
||||||
|
FAIL() << "Excepted exception: Args type is wrong";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test list slice
|
||||||
|
/// Description: Test List slice
|
||||||
|
/// Expectation: No Expectation
|
||||||
|
TEST_F(TestComposite, test_ListSlice_arg_slice) {
|
||||||
|
std::shared_ptr<py::scoped_interpreter> env = parse::python_adapter::set_python_scoped();
|
||||||
|
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
|
||||||
|
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||||
|
|
||||||
|
AbstractBasePtrList eles;
|
||||||
|
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
|
||||||
|
size_t list_size = 6;
|
||||||
|
for (size_t i = 0; i < list_size; i++) {
|
||||||
|
eles.push_back(tensor);
|
||||||
|
}
|
||||||
|
auto list_tensor = std::make_shared<AbstractList>(eles);
|
||||||
|
auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
|
||||||
|
auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(6));
|
||||||
|
auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(2));
|
||||||
|
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
||||||
|
AbstractBasePtrList args_spec_list = {list_tensor, slice};
|
||||||
|
|
||||||
|
AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
|
||||||
|
if (ret == nullptr) {
|
||||||
|
FAIL() << "Cast ret to abstract list failed.";
|
||||||
|
}
|
||||||
|
size_t real = ret->size();
|
||||||
|
size_t expect = 3;
|
||||||
|
ASSERT_EQ(real, expect);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test list slice
|
||||||
|
/// Description: Test List slice the step is none
|
||||||
|
/// Expectation: No Expectation
|
||||||
|
TEST_F(TestComposite, test_ListSlice_arg_slice_step_none) {
|
||||||
|
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
|
||||||
|
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||||
|
|
||||||
|
AbstractBasePtrList eles;
|
||||||
|
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
|
||||||
|
size_t list_size = 6;
|
||||||
|
for (size_t i = 0; i < list_size; i++) {
|
||||||
|
eles.push_back(tensor);
|
||||||
|
}
|
||||||
|
auto list_tensor = std::make_shared<AbstractList>(eles);
|
||||||
|
auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(1));
|
||||||
|
auto stop_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(5));
|
||||||
|
auto step = std::make_shared<AbstractNone>();
|
||||||
|
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
||||||
|
AbstractBasePtrList args_spec_list = {list_tensor, slice};
|
||||||
|
|
||||||
|
AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
|
||||||
|
if (ret == nullptr) {
|
||||||
|
FAIL() << "Cast ret to abstract list failed.";
|
||||||
|
}
|
||||||
|
size_t real = ret->size();
|
||||||
|
size_t expect = 4;
|
||||||
|
ASSERT_EQ(real, expect);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test list slice
|
||||||
|
/// Description: Test List slice the step is negative
|
||||||
|
/// Expectation: No Expectation
|
||||||
|
TEST_F(TestComposite, test_ListSlice_arg_slice_step_negative) {
|
||||||
|
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
|
||||||
|
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||||
|
|
||||||
|
AbstractBasePtrList eles;
|
||||||
|
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
|
||||||
|
size_t list_size = 6;
|
||||||
|
for (size_t i = 0; i < list_size; i++) {
|
||||||
|
eles.push_back(tensor);
|
||||||
|
}
|
||||||
|
auto list_tensor = std::make_shared<AbstractList>(eles);
|
||||||
|
auto start_index = std::make_shared<AbstractNone>();
|
||||||
|
auto stop_index = std::make_shared<AbstractNone>();
|
||||||
|
auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
|
||||||
|
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
||||||
|
AbstractBasePtrList args_spec_list = {list_tensor, slice};
|
||||||
|
|
||||||
|
AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
|
||||||
|
if (ret == nullptr) {
|
||||||
|
FAIL() << "Cast ret to abstract list failed.";
|
||||||
|
}
|
||||||
|
size_t real = ret->size();
|
||||||
|
size_t expect = 6;
|
||||||
|
ASSERT_EQ(real, expect);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feature: Test list slice
|
||||||
|
/// Description: Test List slice the step is positive
|
||||||
|
/// Expectation: No Expectation
|
||||||
|
TEST_F(TestComposite, test_ListSlice_arg_slice_step_positive) {
|
||||||
|
MetaFuncGraphPtr list_slice = std::make_shared<prim::ListSlice>("list_slice");
|
||||||
|
FuncGraphPtr list_slice_graph = UTCompositeUtils::MakeFuncGraph(list_slice, 2);
|
||||||
|
|
||||||
|
AbstractBasePtrList eles;
|
||||||
|
auto tensor = UTCompositeUtils::ArrayInt32Of({2, 3, 4});
|
||||||
|
size_t list_size = 6;
|
||||||
|
for (size_t i = 0; i < list_size; i++) {
|
||||||
|
eles.push_back(tensor);
|
||||||
|
}
|
||||||
|
auto list_tensor = std::make_shared<AbstractList>(eles);
|
||||||
|
auto start_index = std::make_shared<AbstractScalar>(static_cast<int64_t>(-2));
|
||||||
|
auto stop_index = std::make_shared<AbstractNone>();
|
||||||
|
auto step = std::make_shared<AbstractScalar>(static_cast<int64_t>(-1));
|
||||||
|
auto slice = std::make_shared<AbstractSlice>(start_index, stop_index, step);
|
||||||
|
AbstractBasePtrList args_spec_list = {list_tensor, slice};
|
||||||
|
|
||||||
|
AbstractListPtr ret = dyn_cast<AbstractList>(engine_->Run(list_slice_graph, args_spec_list).eval_result->abstract());
|
||||||
|
if (ret == nullptr) {
|
||||||
|
FAIL() << "Cast ret to abstract list failed.";
|
||||||
|
}
|
||||||
|
size_t real = ret->size();
|
||||||
|
size_t expect = 5;
|
||||||
|
ASSERT_EQ(real, expect);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TestComposite, test_UnpackCall_3args) {
|
TEST_F(TestComposite, test_UnpackCall_3args) {
|
||||||
MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
|
MetaFuncGraphPtr unPackCallPtr = std::make_shared<prim::UnpackCall>("UnPackCall");
|
||||||
FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
|
FuncGraphPtr unPackCallGraphPtr = UTCompositeUtils::MakeFuncGraph(unPackCallPtr, 3);
|
||||||
|
|
|
@ -19,6 +19,7 @@ import mindspore.nn as nn
|
||||||
from mindspore import Tensor
|
from mindspore import Tensor
|
||||||
from mindspore.common.api import _cell_graph_executor
|
from mindspore.common.api import _cell_graph_executor
|
||||||
from mindspore.nn import Cell
|
from mindspore.nn import Cell
|
||||||
|
from mindspore import ops
|
||||||
|
|
||||||
|
|
||||||
class Net1(Cell):
|
class Net1(Cell):
|
||||||
|
@ -33,6 +34,17 @@ class Net1(Cell):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Net2(Cell):
|
||||||
|
def __init__(self, list1):
|
||||||
|
super().__init__()
|
||||||
|
self.list = list1
|
||||||
|
self.addn = ops.AddN()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.addn(self.list[0::2])
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def test_list1():
|
def test_list1():
|
||||||
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
input_np = np.random.randn(2, 3, 4, 5).astype(np.float32)
|
||||||
input_me = Tensor(input_np)
|
input_me = Tensor(input_np)
|
||||||
|
@ -45,3 +57,14 @@ def test_list2():
|
||||||
input_me = Tensor(input_np)
|
input_me = Tensor(input_np)
|
||||||
net = Net1([1, 2])
|
net = Net1([1, 2])
|
||||||
_cell_graph_executor.compile(net, input_me)
|
_cell_graph_executor.compile(net, input_me)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_slice():
|
||||||
|
"""
|
||||||
|
Feature: Support List Slice
|
||||||
|
Description: Test List Slice
|
||||||
|
Expectation: No exception.
|
||||||
|
"""
|
||||||
|
input_me = Tensor(8)
|
||||||
|
net = Net2([Tensor(1), Tensor(2), Tensor(3), Tensor(4), Tensor(5), Tensor(6), Tensor(7)])
|
||||||
|
_cell_graph_executor.compile(net, input_me)
|
||||||
|
|
Loading…
Reference in New Issue