!47437 ListExtend supports list, tuple and Tensor

Merge pull request !47437 from huangbingjian/list_extend
This commit is contained in:
i-robot 2023-01-09 06:09:41 +00:00 committed by Gitee
commit 6c1607c2ea
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
4 changed files with 126 additions and 11 deletions

View File

@ -1,5 +1,5 @@
/**
* Copyright 2019-2022 Huawei Technologies Co., Ltd
* Copyright 2019-2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -157,16 +157,24 @@ FuncGraphPtr ListClear::GenerateFuncGraph(const abstract::AbstractBasePtrList &a
}
FuncGraphPtr ListExtend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
abstract::CheckArgsSize("ListExtend", args_list, 2);
constexpr size_t list_extend_args_size = 2;
abstract::CheckArgsSize("ListExtend", args_list, list_extend_args_size);
FuncGraphPtr ret = std::make_shared<FuncGraph>();
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
ret->debug_info()->set_name("extend");
constexpr size_t current_index = 0;
constexpr size_t extend_index = 1;
auto abs_current = args_list[current_index];
auto abs_extend = args_list[extend_index];
std::vector<AnfNodePtr> elems;
elems.push_back(NewValueNode(prim::kPrimMakeList));
AddNodeToElems(args_list[0], ret, &elems);
AddNodeToElems(args_list[1], ret, &elems);
auto abs_current_list = dyn_cast<abstract::AbstractList>(abs_current);
MS_EXCEPTION_IF_NULL(abs_current_list);
AddNodeToElems(abs_current_list, ret, &elems);
AddNodeToElems(abs_extend, ret, &elems);
auto out = ret->NewCNode(elems);
ret->set_output(out);
@ -174,14 +182,55 @@ FuncGraphPtr ListExtend::GenerateFuncGraph(const abstract::AbstractBasePtrList &
}
void ListExtend::AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems) {
abstract::AbstractListPtr arg_list = dyn_cast<abstract::AbstractList>(arg);
MS_EXCEPTION_IF_NULL(arg_list);
int64_t len = SizeToLong(arg_list->size());
AnfNodePtr arg_node = ret->add_parameter();
for (int64_t i = 0; i < len; ++i) {
auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)});
elems->push_back(value);
if (arg->isa<abstract::AbstractList>()) {
auto arg_list = dyn_cast<abstract::AbstractList>(arg);
if (arg_list->dynamic_len()) {
MS_LOG(EXCEPTION) << "ListExtend does not support dynamic length list.";
}
int64_t len = SizeToLong(arg_list->size());
for (int64_t i = 0; i < len; ++i) {
auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)});
elems->push_back(value);
}
return;
}
if (arg->isa<abstract::AbstractTuple>()) {
auto arg_tuple = dyn_cast<abstract::AbstractTuple>(arg);
if (arg_tuple->dynamic_len()) {
MS_LOG(EXCEPTION) << "ListExtend does not support dynamic length tuple.";
}
int64_t len = SizeToLong(arg_tuple->size());
for (int64_t i = 0; i < len; ++i) {
auto value = ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), arg_node, NewValueNode(i)});
elems->push_back(value);
}
return;
}
if (arg->isa<abstract::AbstractTensor>()) {
auto abs_tensor = dyn_cast<abstract::AbstractTensor>(arg);
auto shape_ptr = abs_tensor->BuildShape();
MS_EXCEPTION_IF_NULL(shape_ptr);
auto tensor_shape = shape_ptr->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(tensor_shape);
auto shape = tensor_shape->shape();
if (shape.empty()) {
MS_LOG(EXCEPTION) << "ListExtend does not support scalar tensor.";
}
if (shape[0] < 0) {
MS_LOG(EXCEPTION) << "ListExtend does not support the tensor whose shapes has an uncertain 0th dimension.";
}
int64_t len = shape[0];
std::string module_name = "mindspore.ops.composite.multitype_ops.getitem_impl";
ValuePtr op = prim::GetPythonOps("getitem", module_name);
for (int64_t i = 0; i < len; ++i) {
auto value = ret->NewCNode({NewValueNode(op), arg_node, NewValueNode(i)});
elems->push_back(value);
}
return;
}
MS_LOG(EXCEPTION) << "ListExtend supports list, tuple and Tensor, but got " << arg->ToString();
}
FuncGraphPtr ListReverse::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {

View File

@ -1,3 +1,17 @@
# Copyright 2022-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore.nn import Cell

View File

@ -0,0 +1,35 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test_list_extend """
import numpy as np
import mindspore as ms
def test_list_extend_tensor():
"""
Feature: list extend.
Description: support list extend.
Expectation: No exception.
"""
@ms.jit
def func():
x = []
y = ms.Tensor([[1, 2], [3, 4]])
x.extend(y)
return x
out = func()
assert np.all(out[0].asnumpy() == ms.Tensor([1, 2]).asnumpy())
assert np.all(out[1].asnumpy() == ms.Tensor([3, 4]).asnumpy())

View File

@ -1,4 +1,4 @@
# Copyright 2022 Huawei Technologies Co., Ltd
# Copyright 2022-2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -91,3 +91,20 @@ def test_list_extend_4():
return x
out = list_net_4()
assert np.all(out == ())
def test_list_extend_tuple():
"""
Feature: list extend.
Description: support list extend.
Expectation: No exception.
"""
@jit
def func():
x = [1, 2, 3, 4]
y = (5, 6, 7)
x.extend(y)
return x
out = func()
assert np.all(out == (1, 2, 3, 4, 5, 6, 7))