forked from mindspore-Ecosystem/mindspore
!47437 ListExtend supports list, tuple and Tensor
Merge pull request !47437 from huangbingjian/list_extend
This commit is contained in:
commit
6c1607c2ea
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
* Copyright 2019-2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -157,16 +157,24 @@ FuncGraphPtr ListClear::GenerateFuncGraph(const abstract::AbstractBasePtrList &a
|
|||
}
|
||||
|
||||
FuncGraphPtr ListExtend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||
abstract::CheckArgsSize("ListExtend", args_list, 2);
|
||||
constexpr size_t list_extend_args_size = 2;
|
||||
abstract::CheckArgsSize("ListExtend", args_list, list_extend_args_size);
|
||||
|
||||
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
||||
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
ret->debug_info()->set_name("extend");
|
||||
|
||||
constexpr size_t current_index = 0;
|
||||
constexpr size_t extend_index = 1;
|
||||
auto abs_current = args_list[current_index];
|
||||
auto abs_extend = args_list[extend_index];
|
||||
|
||||
std::vector<AnfNodePtr> elems;
|
||||
elems.push_back(NewValueNode(prim::kPrimMakeList));
|
||||
AddNodeToElems(args_list[0], ret, &elems);
|
||||
AddNodeToElems(args_list[1], ret, &elems);
|
||||
auto abs_current_list = dyn_cast<abstract::AbstractList>(abs_current);
|
||||
MS_EXCEPTION_IF_NULL(abs_current_list);
|
||||
AddNodeToElems(abs_current_list, ret, &elems);
|
||||
AddNodeToElems(abs_extend, ret, &elems);
|
||||
|
||||
auto out = ret->NewCNode(elems);
|
||||
ret->set_output(out);
|
||||
|
@ -174,14 +182,55 @@ FuncGraphPtr ListExtend::GenerateFuncGraph(const abstract::AbstractBasePtrList &
|
|||
}
|
||||
|
||||
void ListExtend::AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems) {
|
||||
abstract::AbstractListPtr arg_list = dyn_cast<abstract::AbstractList>(arg);
|
||||
MS_EXCEPTION_IF_NULL(arg_list);
|
||||
int64_t len = SizeToLong(arg_list->size());
|
||||
AnfNodePtr arg_node = ret->add_parameter();
|
||||
for (int64_t i = 0; i < len; ++i) {
|
||||
auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)});
|
||||
elems->push_back(value);
|
||||
if (arg->isa<abstract::AbstractList>()) {
|
||||
auto arg_list = dyn_cast<abstract::AbstractList>(arg);
|
||||
if (arg_list->dynamic_len()) {
|
||||
MS_LOG(EXCEPTION) << "ListExtend does not support dynamic length list.";
|
||||
}
|
||||
int64_t len = SizeToLong(arg_list->size());
|
||||
for (int64_t i = 0; i < len; ++i) {
|
||||
auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)});
|
||||
elems->push_back(value);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (arg->isa<abstract::AbstractTuple>()) {
|
||||
auto arg_tuple = dyn_cast<abstract::AbstractTuple>(arg);
|
||||
if (arg_tuple->dynamic_len()) {
|
||||
MS_LOG(EXCEPTION) << "ListExtend does not support dynamic length tuple.";
|
||||
}
|
||||
int64_t len = SizeToLong(arg_tuple->size());
|
||||
for (int64_t i = 0; i < len; ++i) {
|
||||
auto value = ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), arg_node, NewValueNode(i)});
|
||||
elems->push_back(value);
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (arg->isa<abstract::AbstractTensor>()) {
|
||||
auto abs_tensor = dyn_cast<abstract::AbstractTensor>(arg);
|
||||
auto shape_ptr = abs_tensor->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(shape_ptr);
|
||||
auto tensor_shape = shape_ptr->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(tensor_shape);
|
||||
auto shape = tensor_shape->shape();
|
||||
if (shape.empty()) {
|
||||
MS_LOG(EXCEPTION) << "ListExtend does not support scalar tensor.";
|
||||
}
|
||||
if (shape[0] < 0) {
|
||||
MS_LOG(EXCEPTION) << "ListExtend does not support the tensor whose shapes has an uncertain 0th dimension.";
|
||||
}
|
||||
int64_t len = shape[0];
|
||||
|
||||
std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
|
||||
ValuePtr op = prim::GetPythonOps("getitem", module_name);
|
||||
for (int64_t i = 0; i < len; ++i) {
|
||||
auto value = ret->NewCNode({NewValueNode(op), arg_node, NewValueNode(i)});
|
||||
elems->push_back(value);
|
||||
}
|
||||
return;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "ListExtend supports list, tuple and Tensor, but got " << arg->ToString();
|
||||
}
|
||||
|
||||
FuncGraphPtr ListReverse::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
# Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import pytest
|
||||
from mindspore.nn import Cell
|
||||
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" test_list_extend """
|
||||
import numpy as np
|
||||
import mindspore as ms
|
||||
|
||||
|
||||
def test_list_extend_tensor():
|
||||
"""
|
||||
Feature: list extend.
|
||||
Description: support list extend.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms.jit
|
||||
def func():
|
||||
x = []
|
||||
y = ms.Tensor([[1, 2], [3, 4]])
|
||||
x.extend(y)
|
||||
return x
|
||||
|
||||
out = func()
|
||||
assert np.all(out[0].asnumpy() == ms.Tensor([1, 2]).asnumpy())
|
||||
assert np.all(out[1].asnumpy() == ms.Tensor([3, 4]).asnumpy())
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
# Copyright 2022-2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -91,3 +91,20 @@ def test_list_extend_4():
|
|||
return x
|
||||
out = list_net_4()
|
||||
assert np.all(out == ())
|
||||
|
||||
|
||||
def test_list_extend_tuple():
|
||||
"""
|
||||
Feature: list extend.
|
||||
Description: support list extend.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def func():
|
||||
x = [1, 2, 3, 4]
|
||||
y = (5, 6, 7)
|
||||
x.extend(y)
|
||||
return x
|
||||
|
||||
out = func()
|
||||
assert np.all(out == (1, 2, 3, 4, 5, 6, 7))
|
||||
|
|
Loading…
Reference in New Issue