Support list clear,count,extend,reverse Method.

This commit is contained in:
broccoli857 2022-08-12 11:21:24 +00:00
parent dceb555b40
commit 147f126c6d
11 changed files with 842 additions and 6 deletions

View File

@ -16,7 +16,6 @@
#include "frontend/operator/composite/list_operation.h"
#include <vector>
#include <string>
#include <memory>
@ -131,5 +130,112 @@ FuncGraphPtr ListPop::GenerateFuncGraph(const abstract::AbstractBasePtrList &arg
ret->set_output(out);
return ret;
}
FuncGraphPtr ListClear::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
abstract::CheckArgsSize("ListClear", args_list, 1);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("clear");
(void)ret->add_parameter();
ret->set_output(ret->NewCNode({NewValueNode(prim::kPrimMakeList)}));
return ret;
}
FuncGraphPtr ListExtend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
abstract::CheckArgsSize("ListExtend", args_list, 2);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("extend");
std::vector<AnfNodePtr> elems;
elems.push_back(NewValueNode(prim::kPrimMakeList));
AddNodeToElems(args_list[0], ret, &elems);
AddNodeToElems(args_list[1], ret, &elems);
auto out = ret->NewCNode(elems);
ret->set_output(out);
return ret;
}
void ListExtend::AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems) {
abstract::AbstractListPtr arg_list = dyn_cast<abstract::AbstractList>(arg);
MS_EXCEPTION_IF_NULL(arg_list);
int64_t len = SizeToLong(arg_list->size());
AnfNodePtr arg_node = ret->add_parameter();
for (int64_t i = 0; i < len; ++i) {
auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)});
elems->push_back(value);
}
}
FuncGraphPtr ListReverse::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
abstract::CheckArgsSize("ListReverse", args_list, 1);
auto &arg0 = args_list[0];
abstract::AbstractListPtr arg_list = dyn_cast<abstract::AbstractList>(arg0);
MS_EXCEPTION_IF_NULL(arg_list);
int64_t arg_length = SizeToLong(arg_list->size());
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("reverse");
AnfNodePtr arg0_node = ret->add_parameter();
std::vector<AnfNodePtr> elems;
elems.push_back(NewValueNode(prim::kPrimMakeList));
for (int64_t i = arg_length - 1; i >= 0; --i) {
elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(SizeToLong(i))}));
}
ret->set_output(ret->NewCNode(elems));
return ret;
}
FuncGraphPtr ListCount::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
const size_t list_count_args_size = 2;
abstract::CheckArgsSize("ListCount", args_list, list_count_args_size);
auto &arg0 = args_list[0];
auto &arg1 = args_list[1];
auto arg0_list = dyn_cast_ptr<abstract::AbstractList>(arg0);
MS_EXCEPTION_IF_NULL(arg0_list);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("count");
(void)ret->add_parameter();
(void)ret->add_parameter();
ValuePtr count_value = arg1->BuildValue();
const auto &values = arg0_list->elements();
int64_t count_ = 0;
for (auto value_ : values) {
if (ComparesTwoValues(count_value, value_->BuildValue())) {
count_++;
}
}
auto out = NewValueNode(MakeValue(count_));
ret->set_output(out);
return ret;
}
bool ListCount::ComparesTwoValues(const ValuePtr &count_value, const ValuePtr &list_value) {
MS_EXCEPTION_IF_NULL(count_value);
MS_EXCEPTION_IF_NULL(list_value);
if (!count_value->IsSameTypeId(list_value->tid())) {
return false;
}
if (count_value->isa<AnyValue>() || list_value->isa<AnyValue>()) {
MS_EXCEPTION(NotSupportError) << "The list count not support " << count_value->type_name() << " type now.";
} else if (count_value->isa<tensor::Tensor>()) {
return count_value->cast_ptr<tensor::Tensor>()->ValueEqual(*list_value->cast_ptr<tensor::Tensor>());
} else {
return *count_value == *list_value;
}
}
} // namespace prim
} // namespace mindspore

View File

@ -19,6 +19,7 @@
#include <string>
#include <memory>
#include <vector>
#include "ir/meta_func_graph.h"
@ -66,6 +67,64 @@ class ListPop : public MetaFuncGraph {
friend bool operator==(const ListPop &lhs, const ListPop &rhs) { return lhs.name_ == rhs.name_; }
};
using ListPopPtr = std::shared_ptr<ListPop>;
class ListClear : public MetaFuncGraph {
public:
explicit ListClear(const std::string &name) : MetaFuncGraph(name) {}
~ListClear() override = default;
MS_DECLARE_PARENT(ListClear, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
friend std::ostream &operator<<(std::ostream &os, const ListClear &list_clear) {
os << list_clear.name_;
return os;
}
friend bool operator==(const ListClear &lhs, const ListClear &rhs) { return lhs.name_ == rhs.name_; }
};
using ListClearPtr = std::shared_ptr<ListClear>;
class ListExtend : public MetaFuncGraph {
public:
explicit ListExtend(const std::string &name) : MetaFuncGraph(name) {}
~ListExtend() override = default;
MS_DECLARE_PARENT(ListExtend, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
friend std::ostream &operator<<(std::ostream &os, const ListExtend &list_extend) {
os << list_extend.name_;
return os;
}
friend bool operator==(const ListExtend &lhs, const ListExtend &rhs) { return lhs.name_ == rhs.name_; }
void AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems);
};
using ListExtendPtr = std::shared_ptr<ListExtend>;
class ListReverse : public MetaFuncGraph {
public:
explicit ListReverse(const std::string &name) : MetaFuncGraph(name) {}
~ListReverse() override = default;
MS_DECLARE_PARENT(ListReverse, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
friend std::ostream &operator<<(std::ostream &os, const ListReverse &list_reverse) {
os << list_reverse.name_;
return os;
}
friend bool operator==(const ListReverse &lhs, const ListReverse &rhs) { return lhs.name_ == rhs.name_; }
};
using ListReversePtr = std::shared_ptr<ListReverse>;
class ListCount : public MetaFuncGraph {
public:
explicit ListCount(const std::string &name) : MetaFuncGraph(name) {}
~ListCount() override = default;
MS_DECLARE_PARENT(ListCount, MetaFuncGraph)
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
friend std::ostream &operator<<(std::ostream &os, const ListCount &list_count) {
os << list_count.name_;
return os;
}
friend bool operator==(const ListCount &lhs, const ListCount &rhs) { return lhs.name_ == rhs.name_; }
bool ComparesTwoValues(const ValuePtr &count_value, const ValuePtr &list_value);
};
using ListCountPtr = std::shared_ptr<ListCount>;
} // namespace prim
} // namespace mindspore

View File

@ -88,6 +88,22 @@ REGISTER_PYBIND_WITH_PARENT_NAME(
(void)py::class_<ListPop, MetaFuncGraph, std::shared_ptr<ListPop>>(*m, "ListPop_")
.def(py::init<const std::string &>());
// Reg ListClear
(void)py::class_<ListClear, MetaFuncGraph, std::shared_ptr<ListClear>>(*m, "ListClear_")
.def(py::init<const std::string &>());
// Reg ListReverse
(void)py::class_<ListReverse, MetaFuncGraph, std::shared_ptr<ListReverse>>(*m, "ListReverse_")
.def(py::init<const std::string &>());
// Reg ListExtend
(void)py::class_<ListExtend, MetaFuncGraph, std::shared_ptr<ListExtend>>(*m, "ListExtend_")
.def(py::init<const std::string &>());
// Reg ListCount
(void)py::class_<ListCount, MetaFuncGraph, std::shared_ptr<ListCount>>(*m, "ListCount_")
.def(py::init<const std::string &>());
// Reg MapPy
(void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
.def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"), py::arg("ops"))

View File

@ -136,8 +136,12 @@ BuiltInTypeMap &GetMethodMap() {
{"append", std::string("list_append")}, // C.list_append
{"__bool__", std::string("list_bool")}, // C.list_bool
{"__ms_hasnext__", std::string("list_hasnext")},
{"insert", std::string("list_insert")}, // C.list_insert
{"pop", std::string("list_pop")} // C.list_pop
{"insert", std::string("list_insert")}, // C.list_insert
{"pop", std::string("list_pop")}, // C.list_pop
{"clear", std::string("list_clear")}, // C.list_clear
{"reverse", std::string("list_reverse")}, // C.list_reverse
{"extend", std::string("list_extend")}, // C.list_extend
{"count", std::string("list_count")} // C.list_count
}},
{kObjectTypeDictionary,
{

View File

@ -92,7 +92,7 @@ SYNTAX_UNSUPPORTED_NAMESPACE = 4 # Unsupported namespace
# Process expr statement white list
# Add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
parse_expr_statement_white_list = (
"append", "insert",
"append", "insert", "clear", "reverse", "extend",
)
_builtin_function_or_method_type = type(abs)

View File

@ -25,7 +25,8 @@ from ...ops import functional as F
from ...ops import operations as P
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
zeros_like, ones_like, repeat_elements
from ...ops.composite.base import _append, _insert, _pop
from ...ops.composite.base import _append, _insert, _pop, _list_clear, _reverse, \
_count, _extend
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
from ...ops.operations.math_ops import Median
@ -2452,6 +2453,22 @@ def list_pop(self_, index=-1):
return self_, pop_val
def list_clear(self_):
return _list_clear(self_)
def list_reverse(self_):
return _reverse(self_)
def list_extend(self_, obj):
return _extend(self_, obj)
def list_count(self_, value):
return _count(self_, value)
def dict_get(self_, key_index, default_value=None):
"""Get value by key from dict"""
return F.dict_getitem(self_, key_index, default_value)

View File

@ -22,7 +22,8 @@ import mindspore as ms
from mindspore import context
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
ListClear_, ListReverse_, ListExtend_, ListCount_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
from ..primitive import Primitive
@ -913,6 +914,82 @@ class _ListPop(ListPop_):
_pop = _ListPop("pop")
class _ListClear(ListClear_):
"""
A metafuncgraph class that clear the list.
Args:
name (str): The name of the metafuncgraph object.
"""
def __init__(self, name):
"""Initialize _ListClear."""
ListClear_.__init__(self, name)
def __call__(self, *args):
pass
_list_clear = _ListClear("clear")
class _ListReverse(ListReverse_):
"""
A metafuncgraph class that reverse the list.
Args:
name (str): The name of the metafuncgraph object.
"""
def __init__(self, name):
"""Initialize _ListReverse."""
ListReverse_.__init__(self, name)
def __call__(self, *args):
pass
_reverse = _ListReverse("reverse")
class _ListExtend(ListExtend_):
"""
A metafuncgraph class that append another list to the end of the list.
Args:
name (str): The name of the metafuncgraph object.
"""
def __init__(self, name):
"""Initialize _ListExtend."""
ListExtend_.__init__(self, name)
def __call__(self, *args):
pass
_extend = _ListExtend("extend")
class _ListCount(ListCount_):
"""
A metafuncgraph class that count the number of times an element appears in list.
Args:
name (str): The name of the metafuncgraph object.
"""
def __init__(self, name):
"""Initialize _ListCount."""
ListCount_.__init__(self, name)
def __call__(self, *args):
pass
_count = _ListCount("count")
class _Tail(Tail_):
"""
A metafuncgraph class that generates tail elements of the tuple.

View File

@ -0,0 +1,104 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_list_clear """
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_clear_1():
"""
Feature: list clear.
Description: support list clear.
Expectation: No exception.
"""
@ms_function
def list_net_1():
x = [1, 2, 3, 4]
x.clear()
return Tensor(x)
out = list_net_1()
assert np.all(out.asnumpy() == ())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_clear_2():
"""
Feature: list clear.
Description: support list clear.
Expectation: No exception.
"""
@ms_function
def list_net_2():
aa = 20
x = ['a', ['bb', '2', 3], aa, 4]
x.clear()
return Tensor(x)
out = list_net_2()
assert np.all(out.asnumpy() == ())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_clear_3():
"""
Feature: list clear.
Description: support list clear.
Expectation: No exception.
"""
@ms_function
def list_net_3():
aa = 20
bb = Tensor(1)
x = ['a', ('Michael', 'Bob', '2'), aa, 4, bb, [1, 2], Tensor(1)]
x.clear()
return Tensor(x)
out = list_net_3()
assert np.all(out.asnumpy() == ())
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_clear_4():
"""
Feature: list clear.
Description: support list clear.
Expectation: No exception.
"""
@ms_function
def list_net_4():
x = []
x.clear()
return Tensor(x)
out = list_net_4()
assert np.all(out.asnumpy() == ())

View File

@ -0,0 +1,235 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_list_count """
import pytest
from mindspore import Tensor, ms_function, context
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count_1():
"""
Feature: list count.
Description: support list count.
Expectation: No exception.
"""
@ms_function
def list_net_1():
x = [1, 2, 3, 4]
res = x.count(1)
return Tensor(res)
out = list_net_1()
assert out == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count_2():
"""
Feature: list count.
Description: support list count.
Expectation: No exception.
"""
@ms_function
def list_net_2():
x = [1, 2, 3, 4]
res = x.count(0)
return Tensor(res)
out = list_net_2()
assert out == 0
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count_3():
"""
Feature: list count.
Description: support list count.
Expectation: No exception.
"""
@ms_function
def list_net_3():
aa = 20
x = ['a', 'b', aa, 4]
res = x.count(aa)
return Tensor(res)
out = list_net_3()
assert out == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count_4():
"""
Feature: list count.
Description: support list count.
Expectation: No exception.
"""
@ms_function
def list_net_4():
aa = 20
bb = 'b'
x = ['a', 'b', aa, 4, bb]
res = x.count(bb)
return Tensor(res)
out = list_net_4()
assert out == 2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count_5():
"""
Feature: list count.
Description: support list count.
Expectation: No exception.
"""
@ms_function
def list_net_5():
aa = 20
x = ['a', ['bb', '2', 3], aa, 4]
res = x.count(['bb', 2, 3])
return Tensor(res)
out = list_net_5()
assert out == 0
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count_6():
"""
Feature: list count.
Description: support list count.
Expectation: No exception.
"""
@ms_function
def list_net_6():
aa = 20
x = ['a', ('Michael', 'Bob', '2'), aa, 4]
res = x.count(('Michael', 'Bob', 2))
return Tensor(res)
out = list_net_6()
assert out == 0
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count_7():
"""
Feature: list count.
Description: support list count.
Expectation: No exception.
"""
@ms_function
def list_net_7():
aa = 20
bb = Tensor(1)
x = ['a', ('Michael', 'Bob', '2'), aa, 4, bb, [1, 2], Tensor(1)]
res = x.count(bb)
return Tensor(res)
out = list_net_7()
assert out == 2
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count_8():
"""
Feature: list count.
Description: support list count.
Expectation: No exception.
"""
@ms_function
def list_net_8():
aa = 20
bb = {'Michael': 1, 'Bob': 'bb', '2': [1, 2]}
x = ['a', {'Michael': 1, 'Bob': 'bb', '2': [1, '2']}, aa, 4, bb]
res = x.count(bb)
return Tensor(res)
out = list_net_8()
assert out == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count_9():
"""
Feature: list count.
Description: support list count.
Expectation: No exception.
"""
@ms_function
def list_net_9():
aa = 20
bb = Tensor([10, 20, True])
x = ['a', {'Michael': 1, 'Bob': 'bb', '2': [1, '2']}, aa, Tensor([10, 20, 2]), bb]
res = x.count(bb)
return Tensor(res)
out = list_net_9()
assert out == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_count_10():
"""
Feature: list count.
Description: support list count.
Expectation: No exception.
"""
@ms_function
def list_net_10(aa, bb):
x = ['a', {'Michael': 1, 'Bob': 'bb', '2': [1, '2']}, aa, aa+bb, bb]
res = x.count(aa + bb)
return Tensor(res)
aa = Tensor(20)
bb = Tensor(10)
with pytest.raises(RuntimeError):
out = list_net_10(aa, bb)
print(out)

View File

@ -0,0 +1,114 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_list_extend """
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_extend_1():
"""
Feature: list extend.
Description: support list extend.
Expectation: No exception.
"""
@ms_function
def list_net_1():
x = [1, 2, 3, 4]
y = [5, 6, 7]
x.extend(y)
return x
out = list_net_1()
assert np.all(out == (1, 2, 3, 4, 5, 6, 7))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_extend_2():
"""
Feature: list extend.
Description: support list extend.
Expectation: No exception.
"""
@ms_function
def list_net_2():
aa = 20
x = [1, 2, 3, 4]
y = [('bb', '2', 3)]
z = [aa]
x.extend(y)
x.extend(z)
return x
out = list_net_2()
assert np.all(out == (1, 2, 3, 4, ('bb', '2', 3), 20))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_extend_3():
"""
Feature: list extend.
Description: support list extend.
Expectation: No exception.
"""
@ms_function
def list_net_3():
aa = 20
bb = Tensor(1)
cc = 'Bob'
x = [1, 2, 3, 4]
y = [('bb', '2', 3), cc]
z = ['a', ('Michael', 'Bob', '2'), aa, 4, bb, (1, 2), Tensor(1)]
x.extend(y)
x.extend(z)
return x
out = list_net_3()
assert np.all(out == (1, 2, 3, 4, ('bb', '2', 3), 'Bob', 'a', ('Michael', 'Bob', '2'), \
20, 4, Tensor(1), (1, 2), Tensor(1)))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_extend_4():
"""
Feature: list extend.
Description: support list extend.
Expectation: No exception.
"""
@ms_function
def list_net_4():
x = []
y = []
x.extend(y)
return x
out = list_net_4()
assert np.all(out == ())

View File

@ -0,0 +1,104 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_list_reverse """
import pytest
import numpy as np
from mindspore import Tensor, ms_function, context
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_reverse_1():
"""
Feature: list reverse.
Description: support list reverse.
Expectation: No exception.
"""
@ms_function
def list_net_1():
x = [1, 2, 3, 4]
x.reverse()
return x
out = list_net_1()
assert np.all(out == (4, 3, 2, 1))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_reverse_2():
"""
Feature: list reverse.
Description: support list reverse.
Expectation: No exception.
"""
@ms_function
def list_net_2():
aa = 20
x = ['a', ('bb', '2', 3), aa, 4]
x.reverse()
return x
out = list_net_2()
assert np.all(out == (4, 20, ('bb', '2', 3), 'a'))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_reverse_3():
"""
Feature: list reverse.
Description: support list reverse.
Expectation: No exception.
"""
@ms_function
def list_net_3():
aa = 20
bb = Tensor(1)
x = ['a', ('Michael', 'Bob', '2'), aa, 4, bb, (1, 2), Tensor(1)]
x.reverse()
return x
out = list_net_3()
assert np.all(out == (Tensor(1), (1, 2), Tensor(1), 4, 20, ('Michael', 'Bob', '2'), 'a'))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_list_reverse_4():
"""
Feature: list reverse.
Description: support list reverse.
Expectation: No exception.
"""
@ms_function
def list_net_4():
x = []
x.reverse()
return x
out = list_net_4()
assert np.all(out == ())