Support list clear,count,extend,reverse Method.
This commit is contained in:
parent
dceb555b40
commit
147f126c6d
|
@ -16,7 +16,6 @@
|
|||
|
||||
#include "frontend/operator/composite/list_operation.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
|
@ -131,5 +130,112 @@ FuncGraphPtr ListPop::GenerateFuncGraph(const abstract::AbstractBasePtrList &arg
|
|||
ret->set_output(out);
|
||||
return ret;
|
||||
}
|
||||
|
||||
FuncGraphPtr ListClear::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||
abstract::CheckArgsSize("ListClear", args_list, 1);
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("clear");
|
||||
(void)ret->add_parameter();
|
||||
|
||||
ret->set_output(ret->NewCNode({NewValueNode(prim::kPrimMakeList)}));
|
||||
return ret;
|
||||
}
|
||||
|
||||
FuncGraphPtr ListExtend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||
abstract::CheckArgsSize("ListExtend", args_list, 2);
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("extend");
|
||||
|
||||
std::vector<AnfNodePtr> elems;
|
||||
elems.push_back(NewValueNode(prim::kPrimMakeList));
|
||||
AddNodeToElems(args_list[0], ret, &elems);
|
||||
AddNodeToElems(args_list[1], ret, &elems);
|
||||
|
||||
auto out = ret->NewCNode(elems);
|
||||
ret->set_output(out);
|
||||
return ret;
|
||||
}
|
||||
|
||||
void ListExtend::AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems) {
|
||||
abstract::AbstractListPtr arg_list = dyn_cast<abstract::AbstractList>(arg);
|
||||
MS_EXCEPTION_IF_NULL(arg_list);
|
||||
int64_t len = SizeToLong(arg_list->size());
|
||||
AnfNodePtr arg_node = ret->add_parameter();
|
||||
for (int64_t i = 0; i < len; ++i) {
|
||||
auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)});
|
||||
elems->push_back(value);
|
||||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr ListReverse::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||
abstract::CheckArgsSize("ListReverse", args_list, 1);
|
||||
auto &arg0 = args_list[0];
|
||||
abstract::AbstractListPtr arg_list = dyn_cast<abstract::AbstractList>(arg0);
|
||||
MS_EXCEPTION_IF_NULL(arg_list);
|
||||
int64_t arg_length = SizeToLong(arg_list->size());
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("reverse");
|
||||
AnfNodePtr arg0_node = ret->add_parameter();
|
||||
|
||||
std::vector<AnfNodePtr> elems;
|
||||
elems.push_back(NewValueNode(prim::kPrimMakeList));
|
||||
for (int64_t i = arg_length - 1; i >= 0; --i) {
|
||||
elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg0_node, NewValueNode(SizeToLong(i))}));
|
||||
}
|
||||
|
||||
ret->set_output(ret->NewCNode(elems));
|
||||
return ret;
|
||||
}
|
||||
|
||||
FuncGraphPtr ListCount::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||
const size_t list_count_args_size = 2;
|
||||
abstract::CheckArgsSize("ListCount", args_list, list_count_args_size);
|
||||
auto &arg0 = args_list[0];
|
||||
auto &arg1 = args_list[1];
|
||||
|
||||
auto arg0_list = dyn_cast_ptr<abstract::AbstractList>(arg0);
|
||||
MS_EXCEPTION_IF_NULL(arg0_list);
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("count");
|
||||
(void)ret->add_parameter();
|
||||
(void)ret->add_parameter();
|
||||
|
||||
ValuePtr count_value = arg1->BuildValue();
|
||||
const auto &values = arg0_list->elements();
|
||||
int64_t count_ = 0;
|
||||
for (auto value_ : values) {
|
||||
if (ComparesTwoValues(count_value, value_->BuildValue())) {
|
||||
count_++;
|
||||
}
|
||||
}
|
||||
|
||||
auto out = NewValueNode(MakeValue(count_));
|
||||
ret->set_output(out);
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool ListCount::ComparesTwoValues(const ValuePtr &count_value, const ValuePtr &list_value) {
|
||||
MS_EXCEPTION_IF_NULL(count_value);
|
||||
MS_EXCEPTION_IF_NULL(list_value);
|
||||
|
||||
if (!count_value->IsSameTypeId(list_value->tid())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (count_value->isa<AnyValue>() || list_value->isa<AnyValue>()) {
|
||||
MS_EXCEPTION(NotSupportError) << "The list count not support " << count_value->type_name() << " type now.";
|
||||
} else if (count_value->isa<tensor::Tensor>()) {
|
||||
return count_value->cast_ptr<tensor::Tensor>()->ValueEqual(*list_value->cast_ptr<tensor::Tensor>());
|
||||
} else {
|
||||
return *count_value == *list_value;
|
||||
}
|
||||
}
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "ir/meta_func_graph.h"
|
||||
|
||||
|
@ -66,6 +67,64 @@ class ListPop : public MetaFuncGraph {
|
|||
friend bool operator==(const ListPop &lhs, const ListPop &rhs) { return lhs.name_ == rhs.name_; }
|
||||
};
|
||||
using ListPopPtr = std::shared_ptr<ListPop>;
|
||||
|
||||
class ListClear : public MetaFuncGraph {
|
||||
public:
|
||||
explicit ListClear(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~ListClear() override = default;
|
||||
MS_DECLARE_PARENT(ListClear, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
|
||||
friend std::ostream &operator<<(std::ostream &os, const ListClear &list_clear) {
|
||||
os << list_clear.name_;
|
||||
return os;
|
||||
}
|
||||
friend bool operator==(const ListClear &lhs, const ListClear &rhs) { return lhs.name_ == rhs.name_; }
|
||||
};
|
||||
using ListClearPtr = std::shared_ptr<ListClear>;
|
||||
|
||||
class ListExtend : public MetaFuncGraph {
|
||||
public:
|
||||
explicit ListExtend(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~ListExtend() override = default;
|
||||
MS_DECLARE_PARENT(ListExtend, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
|
||||
friend std::ostream &operator<<(std::ostream &os, const ListExtend &list_extend) {
|
||||
os << list_extend.name_;
|
||||
return os;
|
||||
}
|
||||
friend bool operator==(const ListExtend &lhs, const ListExtend &rhs) { return lhs.name_ == rhs.name_; }
|
||||
void AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems);
|
||||
};
|
||||
using ListExtendPtr = std::shared_ptr<ListExtend>;
|
||||
|
||||
class ListReverse : public MetaFuncGraph {
|
||||
public:
|
||||
explicit ListReverse(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~ListReverse() override = default;
|
||||
MS_DECLARE_PARENT(ListReverse, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
|
||||
friend std::ostream &operator<<(std::ostream &os, const ListReverse &list_reverse) {
|
||||
os << list_reverse.name_;
|
||||
return os;
|
||||
}
|
||||
friend bool operator==(const ListReverse &lhs, const ListReverse &rhs) { return lhs.name_ == rhs.name_; }
|
||||
};
|
||||
using ListReversePtr = std::shared_ptr<ListReverse>;
|
||||
|
||||
class ListCount : public MetaFuncGraph {
|
||||
public:
|
||||
explicit ListCount(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~ListCount() override = default;
|
||||
MS_DECLARE_PARENT(ListCount, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override;
|
||||
friend std::ostream &operator<<(std::ostream &os, const ListCount &list_count) {
|
||||
os << list_count.name_;
|
||||
return os;
|
||||
}
|
||||
friend bool operator==(const ListCount &lhs, const ListCount &rhs) { return lhs.name_ == rhs.name_; }
|
||||
bool ComparesTwoValues(const ValuePtr &count_value, const ValuePtr &list_value);
|
||||
};
|
||||
using ListCountPtr = std::shared_ptr<ListCount>;
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -88,6 +88,22 @@ REGISTER_PYBIND_WITH_PARENT_NAME(
|
|||
(void)py::class_<ListPop, MetaFuncGraph, std::shared_ptr<ListPop>>(*m, "ListPop_")
|
||||
.def(py::init<const std::string &>());
|
||||
|
||||
// Reg ListClear
|
||||
(void)py::class_<ListClear, MetaFuncGraph, std::shared_ptr<ListClear>>(*m, "ListClear_")
|
||||
.def(py::init<const std::string &>());
|
||||
|
||||
// Reg ListReverse
|
||||
(void)py::class_<ListReverse, MetaFuncGraph, std::shared_ptr<ListReverse>>(*m, "ListReverse_")
|
||||
.def(py::init<const std::string &>());
|
||||
|
||||
// Reg ListExtend
|
||||
(void)py::class_<ListExtend, MetaFuncGraph, std::shared_ptr<ListExtend>>(*m, "ListExtend_")
|
||||
.def(py::init<const std::string &>());
|
||||
|
||||
// Reg ListCount
|
||||
(void)py::class_<ListCount, MetaFuncGraph, std::shared_ptr<ListCount>>(*m, "ListCount_")
|
||||
.def(py::init<const std::string &>());
|
||||
|
||||
// Reg MapPy
|
||||
(void)py::class_<MapPy, MetaFuncGraph, std::shared_ptr<MapPy>>(*m, "Map_")
|
||||
.def(py::init<bool, std::shared_ptr<MultitypeFuncGraph>>(), py::arg("reverse"), py::arg("ops"))
|
||||
|
|
|
@ -136,8 +136,12 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"append", std::string("list_append")}, // C.list_append
|
||||
{"__bool__", std::string("list_bool")}, // C.list_bool
|
||||
{"__ms_hasnext__", std::string("list_hasnext")},
|
||||
{"insert", std::string("list_insert")}, // C.list_insert
|
||||
{"pop", std::string("list_pop")} // C.list_pop
|
||||
{"insert", std::string("list_insert")}, // C.list_insert
|
||||
{"pop", std::string("list_pop")}, // C.list_pop
|
||||
{"clear", std::string("list_clear")}, // C.list_clear
|
||||
{"reverse", std::string("list_reverse")}, // C.list_reverse
|
||||
{"extend", std::string("list_extend")}, // C.list_extend
|
||||
{"count", std::string("list_count")} // C.list_count
|
||||
}},
|
||||
{kObjectTypeDictionary,
|
||||
{
|
||||
|
|
|
@ -92,7 +92,7 @@ SYNTAX_UNSUPPORTED_NAMESPACE = 4 # Unsupported namespace
|
|||
# Process expr statement white list
|
||||
# Add as needed, eg: "clear", "extend", "insert", "remove", "reverse"
|
||||
parse_expr_statement_white_list = (
|
||||
"append", "insert",
|
||||
"append", "insert", "clear", "reverse", "extend",
|
||||
)
|
||||
|
||||
_builtin_function_or_method_type = type(abs)
|
||||
|
|
|
@ -25,7 +25,8 @@ from ...ops import functional as F
|
|||
from ...ops import operations as P
|
||||
from ...ops.composite import tail, core, MultitypeFuncGraph, env_get, hyper_add, \
|
||||
zeros_like, ones_like, repeat_elements
|
||||
from ...ops.composite.base import _append, _insert, _pop
|
||||
from ...ops.composite.base import _append, _insert, _pop, _list_clear, _reverse, \
|
||||
_count, _extend
|
||||
from ...ops.composite.multitype_ops import _constexpr_utils as const_utils
|
||||
from ...ops.composite.multitype_ops import _compile_utils as compile_utils
|
||||
from ...ops.operations.math_ops import Median
|
||||
|
@ -2452,6 +2453,22 @@ def list_pop(self_, index=-1):
|
|||
return self_, pop_val
|
||||
|
||||
|
||||
def list_clear(self_):
|
||||
return _list_clear(self_)
|
||||
|
||||
|
||||
def list_reverse(self_):
|
||||
return _reverse(self_)
|
||||
|
||||
|
||||
def list_extend(self_, obj):
|
||||
return _extend(self_, obj)
|
||||
|
||||
|
||||
def list_count(self_, value):
|
||||
return _count(self_, value)
|
||||
|
||||
|
||||
def dict_get(self_, key_index, default_value=None):
|
||||
"""Get value by key from dict"""
|
||||
return F.dict_getitem(self_, key_index, default_value)
|
||||
|
|
|
@ -22,7 +22,8 @@ import mindspore as ms
|
|||
from mindspore import context
|
||||
from ..._c_expression import GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, Shard_, \
|
||||
TupleAdd_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_, ListInsert_, \
|
||||
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_
|
||||
SequenceSliceGetItem_, ListSliceSetItem_, VmapOperation_, TaylorOperation_, ListPop_, \
|
||||
ListClear_, ListReverse_, ListExtend_, ListCount_
|
||||
from ...common import dtype as mstype
|
||||
from ...common.api import ms_function, _pynative_executor, _wrap_func
|
||||
from ..primitive import Primitive
|
||||
|
@ -913,6 +914,82 @@ class _ListPop(ListPop_):
|
|||
_pop = _ListPop("pop")
|
||||
|
||||
|
||||
class _ListClear(ListClear_):
|
||||
"""
|
||||
A metafuncgraph class that clear the list.
|
||||
|
||||
Args:
|
||||
name (str): The name of the metafuncgraph object.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _ListClear."""
|
||||
ListClear_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
_list_clear = _ListClear("clear")
|
||||
|
||||
|
||||
class _ListReverse(ListReverse_):
|
||||
"""
|
||||
A metafuncgraph class that reverse the list.
|
||||
|
||||
Args:
|
||||
name (str): The name of the metafuncgraph object.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _ListReverse."""
|
||||
ListReverse_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
_reverse = _ListReverse("reverse")
|
||||
|
||||
|
||||
class _ListExtend(ListExtend_):
|
||||
"""
|
||||
A metafuncgraph class that append another list to the end of the list.
|
||||
|
||||
Args:
|
||||
name (str): The name of the metafuncgraph object.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _ListExtend."""
|
||||
ListExtend_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
_extend = _ListExtend("extend")
|
||||
|
||||
|
||||
class _ListCount(ListCount_):
|
||||
"""
|
||||
A metafuncgraph class that count the number of times an element appears in list.
|
||||
|
||||
Args:
|
||||
name (str): The name of the metafuncgraph object.
|
||||
"""
|
||||
|
||||
def __init__(self, name):
|
||||
"""Initialize _ListCount."""
|
||||
ListCount_.__init__(self, name)
|
||||
|
||||
def __call__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
_count = _ListCount("count")
|
||||
|
||||
|
||||
class _Tail(Tail_):
|
||||
"""
|
||||
A metafuncgraph class that generates tail elements of the tuple.
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_list_clear """
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_clear_1():
|
||||
"""
|
||||
Feature: list clear.
|
||||
Description: support list clear.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_1():
|
||||
x = [1, 2, 3, 4]
|
||||
x.clear()
|
||||
return Tensor(x)
|
||||
out = list_net_1()
|
||||
assert np.all(out.asnumpy() == ())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_clear_2():
|
||||
"""
|
||||
Feature: list clear.
|
||||
Description: support list clear.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_2():
|
||||
aa = 20
|
||||
x = ['a', ['bb', '2', 3], aa, 4]
|
||||
x.clear()
|
||||
return Tensor(x)
|
||||
out = list_net_2()
|
||||
assert np.all(out.asnumpy() == ())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_clear_3():
|
||||
"""
|
||||
Feature: list clear.
|
||||
Description: support list clear.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_3():
|
||||
aa = 20
|
||||
bb = Tensor(1)
|
||||
x = ['a', ('Michael', 'Bob', '2'), aa, 4, bb, [1, 2], Tensor(1)]
|
||||
x.clear()
|
||||
return Tensor(x)
|
||||
out = list_net_3()
|
||||
assert np.all(out.asnumpy() == ())
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_clear_4():
|
||||
"""
|
||||
Feature: list clear.
|
||||
Description: support list clear.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_4():
|
||||
x = []
|
||||
x.clear()
|
||||
return Tensor(x)
|
||||
out = list_net_4()
|
||||
assert np.all(out.asnumpy() == ())
|
|
@ -0,0 +1,235 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_list_count """
|
||||
import pytest
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count_1():
|
||||
"""
|
||||
Feature: list count.
|
||||
Description: support list count.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_1():
|
||||
x = [1, 2, 3, 4]
|
||||
res = x.count(1)
|
||||
return Tensor(res)
|
||||
out = list_net_1()
|
||||
assert out == 1
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count_2():
|
||||
"""
|
||||
Feature: list count.
|
||||
Description: support list count.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_2():
|
||||
x = [1, 2, 3, 4]
|
||||
res = x.count(0)
|
||||
return Tensor(res)
|
||||
out = list_net_2()
|
||||
assert out == 0
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count_3():
|
||||
"""
|
||||
Feature: list count.
|
||||
Description: support list count.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_3():
|
||||
aa = 20
|
||||
x = ['a', 'b', aa, 4]
|
||||
res = x.count(aa)
|
||||
return Tensor(res)
|
||||
out = list_net_3()
|
||||
assert out == 1
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count_4():
|
||||
"""
|
||||
Feature: list count.
|
||||
Description: support list count.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_4():
|
||||
aa = 20
|
||||
bb = 'b'
|
||||
x = ['a', 'b', aa, 4, bb]
|
||||
res = x.count(bb)
|
||||
return Tensor(res)
|
||||
out = list_net_4()
|
||||
assert out == 2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count_5():
|
||||
"""
|
||||
Feature: list count.
|
||||
Description: support list count.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_5():
|
||||
aa = 20
|
||||
x = ['a', ['bb', '2', 3], aa, 4]
|
||||
res = x.count(['bb', 2, 3])
|
||||
return Tensor(res)
|
||||
out = list_net_5()
|
||||
assert out == 0
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count_6():
|
||||
"""
|
||||
Feature: list count.
|
||||
Description: support list count.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_6():
|
||||
aa = 20
|
||||
x = ['a', ('Michael', 'Bob', '2'), aa, 4]
|
||||
res = x.count(('Michael', 'Bob', 2))
|
||||
return Tensor(res)
|
||||
out = list_net_6()
|
||||
assert out == 0
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count_7():
|
||||
"""
|
||||
Feature: list count.
|
||||
Description: support list count.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_7():
|
||||
aa = 20
|
||||
bb = Tensor(1)
|
||||
x = ['a', ('Michael', 'Bob', '2'), aa, 4, bb, [1, 2], Tensor(1)]
|
||||
res = x.count(bb)
|
||||
return Tensor(res)
|
||||
out = list_net_7()
|
||||
assert out == 2
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count_8():
|
||||
"""
|
||||
Feature: list count.
|
||||
Description: support list count.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_8():
|
||||
aa = 20
|
||||
bb = {'Michael': 1, 'Bob': 'bb', '2': [1, 2]}
|
||||
x = ['a', {'Michael': 1, 'Bob': 'bb', '2': [1, '2']}, aa, 4, bb]
|
||||
res = x.count(bb)
|
||||
return Tensor(res)
|
||||
out = list_net_8()
|
||||
assert out == 1
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count_9():
|
||||
"""
|
||||
Feature: list count.
|
||||
Description: support list count.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_9():
|
||||
aa = 20
|
||||
bb = Tensor([10, 20, True])
|
||||
x = ['a', {'Michael': 1, 'Bob': 'bb', '2': [1, '2']}, aa, Tensor([10, 20, 2]), bb]
|
||||
res = x.count(bb)
|
||||
return Tensor(res)
|
||||
out = list_net_9()
|
||||
assert out == 1
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_count_10():
|
||||
"""
|
||||
Feature: list count.
|
||||
Description: support list count.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_10(aa, bb):
|
||||
x = ['a', {'Michael': 1, 'Bob': 'bb', '2': [1, '2']}, aa, aa+bb, bb]
|
||||
res = x.count(aa + bb)
|
||||
return Tensor(res)
|
||||
|
||||
aa = Tensor(20)
|
||||
bb = Tensor(10)
|
||||
with pytest.raises(RuntimeError):
|
||||
out = list_net_10(aa, bb)
|
||||
print(out)
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_list_extend """
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_extend_1():
|
||||
"""
|
||||
Feature: list extend.
|
||||
Description: support list extend.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_1():
|
||||
x = [1, 2, 3, 4]
|
||||
y = [5, 6, 7]
|
||||
x.extend(y)
|
||||
return x
|
||||
out = list_net_1()
|
||||
assert np.all(out == (1, 2, 3, 4, 5, 6, 7))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_extend_2():
|
||||
"""
|
||||
Feature: list extend.
|
||||
Description: support list extend.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_2():
|
||||
aa = 20
|
||||
x = [1, 2, 3, 4]
|
||||
y = [('bb', '2', 3)]
|
||||
z = [aa]
|
||||
x.extend(y)
|
||||
x.extend(z)
|
||||
return x
|
||||
out = list_net_2()
|
||||
assert np.all(out == (1, 2, 3, 4, ('bb', '2', 3), 20))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_extend_3():
|
||||
"""
|
||||
Feature: list extend.
|
||||
Description: support list extend.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_3():
|
||||
aa = 20
|
||||
bb = Tensor(1)
|
||||
cc = 'Bob'
|
||||
x = [1, 2, 3, 4]
|
||||
y = [('bb', '2', 3), cc]
|
||||
z = ['a', ('Michael', 'Bob', '2'), aa, 4, bb, (1, 2), Tensor(1)]
|
||||
x.extend(y)
|
||||
x.extend(z)
|
||||
return x
|
||||
out = list_net_3()
|
||||
assert np.all(out == (1, 2, 3, 4, ('bb', '2', 3), 'Bob', 'a', ('Michael', 'Bob', '2'), \
|
||||
20, 4, Tensor(1), (1, 2), Tensor(1)))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_extend_4():
|
||||
"""
|
||||
Feature: list extend.
|
||||
Description: support list extend.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_4():
|
||||
x = []
|
||||
y = []
|
||||
x.extend(y)
|
||||
return x
|
||||
out = list_net_4()
|
||||
assert np.all(out == ())
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_list_reverse """
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mindspore import Tensor, ms_function, context
|
||||
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_reverse_1():
|
||||
"""
|
||||
Feature: list reverse.
|
||||
Description: support list reverse.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_1():
|
||||
x = [1, 2, 3, 4]
|
||||
x.reverse()
|
||||
return x
|
||||
out = list_net_1()
|
||||
assert np.all(out == (4, 3, 2, 1))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_reverse_2():
|
||||
"""
|
||||
Feature: list reverse.
|
||||
Description: support list reverse.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_2():
|
||||
aa = 20
|
||||
x = ['a', ('bb', '2', 3), aa, 4]
|
||||
x.reverse()
|
||||
return x
|
||||
out = list_net_2()
|
||||
assert np.all(out == (4, 20, ('bb', '2', 3), 'a'))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_reverse_3():
|
||||
"""
|
||||
Feature: list reverse.
|
||||
Description: support list reverse.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_3():
|
||||
aa = 20
|
||||
bb = Tensor(1)
|
||||
x = ['a', ('Michael', 'Bob', '2'), aa, 4, bb, (1, 2), Tensor(1)]
|
||||
x.reverse()
|
||||
return x
|
||||
out = list_net_3()
|
||||
assert np.all(out == (Tensor(1), (1, 2), Tensor(1), 4, 20, ('Michael', 'Bob', '2'), 'a'))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_list_reverse_4():
|
||||
"""
|
||||
Feature: list reverse.
|
||||
Description: support list reverse.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def list_net_4():
|
||||
x = []
|
||||
x.reverse()
|
||||
return x
|
||||
out = list_net_4()
|
||||
assert np.all(out == ())
|
Loading…
Reference in New Issue