!49020 Enable list to return in graph (Frontend)

Merge pull request !49020 from LiangZhibo/return
i-robot 2023-03-06 11:41:23 +00:00 committed by Gitee
commit 7bf397ea1c
26 changed files with 695 additions and 242 deletions
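In short: graph outputs of list type are no longer flattened into tuples on the way back to Python. A minimal sketch of the new behavior, mirroring the tests added below (jit, Tensor, and context are the MindSpore APIs those tests import):

from mindspore import Tensor, jit, context

context.set_context(mode=context.GRAPH_MODE)

@jit
def foo():
    return [1, 2, Tensor([3])]

# Before this change the frontend rewrote MakeList into MakeTuple, so the
# caller received a tuple; list outputs now come back as Python lists.
out = foo()
assert isinstance(out, list)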

View File

@@ -40,6 +40,21 @@ bool ConvertListToTuple::Run(const FuncGraphPtr &graph) {
<< ",debug name:" << node->DebugString();
}
}
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
for (auto node : graph->parameters()) {
MS_EXCEPTION_IF_NULL(node);
// Convert unused list parameter to tuple.
if (manager->node_users().find(node) != manager->node_users().end()) {
continue;
}
auto new_abs = ConvertSequenceAbsToTupleAbs(node->abstract());
if (new_abs != nullptr) {
node->set_abstract(new_abs);
MS_LOG(INFO) << "Convert sequence abstract to tuple abstract for op:" << node->fullname_with_scope()
<< ",debug name:" << node->DebugString();
}
}
return true;
}
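At the Python level, the parameter handling added above means a graph can take a list argument it never uses; the pass rewrites the unused parameter's list abstract to a tuple abstract so the backend sees a consistent type. A hypothetical illustration (foo and its arguments are made up for this sketch):

from mindspore import Tensor, jit
from mindspore.common import mutable

@jit
def foo(unused, x):
    # `unused` has no users inside the graph; the loop above still converts
    # its sequence abstract so compilation does not trip over the list type.
    return x + 1

out = foo(mutable([Tensor([1]), Tensor([2])]), Tensor([3]))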

View File

@@ -613,63 +613,6 @@ class CleanAfterOptARewriter : public BaseRewriter {
~CleanAfterOptARewriter() override = default;
protected:
// From:
// MakeList(arg1, arg2, ...)
// To:
// MakeTuple(arg1, arg2, ...)
AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
std::vector<AnfNodePtr> inputs;
inputs.reserve(node->size());
(void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
// Inputs of the node should be [make_list, item1, item2, ...], so offset by 1 to get the items.
(void)inputs.insert(inputs.cend(), node->inputs().cbegin() + 1, node->inputs().cend());
return node->func_graph()->NewCNode(std::move(inputs));
}
// From:
// list_getitem(list, key)
// To:
// TupleGetItem(list, key)
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
// Inputs should be [list_getitem, list, item]
constexpr size_t expect_input_size = 3;
CheckInputsSize(node, expect_input_size);
constexpr size_t data_index = 1;
constexpr size_t cons_index = 2;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &key = inputs[cons_index];
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, key});
}
// From:
// ListSetItem(list, index, item)
// To:
// TupleSetItem(list, index, item)
AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(node->func_graph());
// Inputs should be [list_setitem, list, index, item]
const size_t expect_inputs_size = 4;
CheckInputsSize(node, expect_inputs_size);
const size_t data_index = 1;
const size_t cons_index = 2;
const size_t value_index = 3;
const auto &inputs = node->inputs();
auto &data = inputs[data_index];
auto &key = inputs[cons_index];
auto &value = inputs[value_index];
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, key, value});
}
// From:
// MakeSparseTensor(indices, values, dense_shape)
// To:
@@ -922,9 +865,6 @@ class CleanAfterOptARewriter : public BaseRewriter {
using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &);
using ConverterMap = mindspore::HashMap<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>;
static inline const ConverterMap converters_{
{prim::kPrimMakeList, &ThisClass::ConvertMakeListToMakeTuple},
{prim::kPrimListGetItem, &ThisClass::ConvertListGetItemToTupleGetItem},
{prim::kPrimListSetItem, &ThisClass::ConvertListSetItemToTupleSetItem},
// SparseProcess: 1.MakeSparse->MakeTuple 2.SparseGetAttr->TupleGetItem
{prim::kPrimMakeRowTensor, &ThisClass::ConvertMakeSparseToMakeTuple},
{prim::kPrimRowTensorGetIndices, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
@@ -994,39 +934,18 @@ class CleanAfterOptARewriter : public BaseRewriter {
return (this->*(iter->second))(cnode);
}
static ValuePtr ConvertValueSequenceToValueTuple(const ValuePtr &value, size_t depth, bool *need_convert) {
MS_EXCEPTION_IF_NULL(need_convert);
MS_EXCEPTION_IF_NULL(value);
if (depth > kMaxSeqRecursiveDepth) {
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
}
if (value->isa<ValueSequence>()) {
std::vector<ValuePtr> elements;
auto value_seq = value->cast<ValueSequencePtr>();
(void)std::transform(value_seq->value().begin(), value_seq->value().end(), std::back_inserter(elements),
[&](const ValuePtr &value) -> ValuePtr {
bool is_convert = false;
auto convert_value = ConvertValueSequenceToValueTuple(value, depth + 1, &is_convert);
*need_convert |= is_convert;
return convert_value;
});
*need_convert |= value->isa<ValueList>();
if (*need_convert) {
return std::make_shared<ValueTuple>(elements);
}
}
return value;
}
AnfNodePtr ProcessValueSequence(const ValuePtr &value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<ValueSequence>()) {
auto value_seq = value->cast<ValueSequencePtr>();
MS_EXCEPTION_IF_NULL(value_seq);
auto values = value_seq->value();
std::vector<AnfNodePtr> value_seq_inputs{NewValueNode(prim::kPrimMakeTuple)};
std::vector<AnfNodePtr> value_seq_inputs;
if (value_seq->isa<ValueList>()) {
(void)value_seq_inputs.emplace_back(NewValueNode(prim::kPrimMakeList));
} else {
(void)value_seq_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
}
for (auto inner_value : values) {
auto inner_value_seq = ProcessValueSequence(inner_value);
(void)value_seq_inputs.emplace_back(inner_value_seq);
@@ -1091,60 +1010,14 @@ class CleanAfterOptARewriter : public BaseRewriter {
return ConvertInterpretedObjectValue(value_node, value->cast<parse::InterpretedObjectPtr>());
}
}
bool need_convert = false;
auto convert_value = ConvertValueSequenceToValueTuple(value, 0, &need_convert);
if (need_convert) {
return std::make_shared<ValueNode>(convert_value);
}
return nullptr;
}
// AbstractSequence, AbstractRowTensor --> AbstractTuple.
// AbstractRowTensor --> AbstractTuple.
static AbstractBasePtr ConvertToAbstractTuple(const AbstractBasePtr &abs, size_t depth) {
if (depth > kMaxSeqRecursiveDepth) {
MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
}
// AbstractList --> AbstractTuple.
auto abs_seq = abs->cast<AbstractSequencePtr>();
if (abs_seq != nullptr) {
// A dynamic length list is converted to a dynamic length tuple directly, without recursing into its elements.
if (abs_seq->dynamic_len() && abs_seq->isa<AbstractList>()) {
auto converted_abs_tuple = std::make_shared<AbstractTuple>(abs_seq->elements(), abs_seq->sequence_nodes());
converted_abs_tuple->set_dynamic_len(true);
converted_abs_tuple->set_dynamic_len_element_abs(abs_seq->dynamic_len_element_abs());
return converted_abs_tuple;
}
const auto &seq_elements = abs_seq->elements();
// First we check if elements should be converted,
// changed_elements maps old element to new element.
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
for (const auto &element : seq_elements) {
auto new_element = ConvertToAbstractTuple(element, depth + 1);
if (new_element != nullptr) {
(void)changed_elements.emplace(element, new_element);
}
}
if (changed_elements.empty()) {
if (abs->isa<AbstractTuple>()) {
// If no elements changed and it is an AbstractTuple, do not convert.
return nullptr;
}
// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
return std::make_shared<AbstractTuple>(seq_elements);
}
// Always make new AbstractTuple when elements changed.
std::vector<AbstractBasePtr> elements;
elements.reserve(seq_elements.size());
for (const auto &element : seq_elements) {
auto iter = changed_elements.find(element);
if (iter != changed_elements.end()) {
(void)elements.emplace_back(iter->second);
} else {
(void)elements.emplace_back(element);
}
}
return std::make_shared<AbstractTuple>(std::move(elements));
}
// AbstractRowTensor --> AbstractTuple.
auto abs_row_tensor = abs->cast<std::shared_ptr<AbstractRowTensor>>();
if (abs_row_tensor != nullptr) {
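The net effect of the rewriter changes above, sketched in Python and consistent with the new tests in this PR: MakeList nodes and list-valued constants are no longer rewritten to tuples after opt A, so list outputs keep their type while tuples stay tuples:

from mindspore import jit, context

context.set_context(mode=context.GRAPH_MODE)

@jit
def foo():
    return [1, 2, 3], (4, 5, 6)

lst, tup = foo()
assert isinstance(lst, list)
assert isinstance(tup, tuple)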

View File

@@ -502,6 +502,49 @@ py::object VectorToPyData(const Any &value) {
}
return ret;
}
template <typename T>
py::object AbstractSequenceToPyData(const VectorRef &value_list, const AbstractBasePtr &abs) {
auto value_size = value_list.size();
auto ret = T(value_size);
auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(seq_abs);
bool dynamic_len = seq_abs->dynamic_len();
auto dynamic_len_element_abs = seq_abs->dynamic_len_element_abs();
if (dynamic_len || dynamic_len_element_abs != nullptr) {
if (dynamic_len_element_abs == nullptr) {
MS_LOG(INFO) << "Dynamic length sequence with no specified element abstract convert to empty tuple.";
for (size_t i = 0; i < value_size; i++) {
ret[i] = BaseRefToPyData(value_list[i]);
}
return ret;
}
if (dynamic_len_element_abs->isa<abstract::AbstractNone>()) {
MS_LOG(INFO) << "Dynamic length sequence with element None convert to empty sequence.";
return ret;
}
for (size_t i = 0; i < value_size; ++i) {
ret[i] = BaseRefToPyData(value_list[i], dynamic_len_element_abs);
}
return ret;
}
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
// If FALLBACK_RUNTIME is not enabled, the size of seq_abs may be larger than
// the size of value_list, because the backend will eliminate None.
size_t ref_idx = 0;
for (size_t i = 0; i < seq_abs->size(); i++) {
auto elem_abs = seq_abs->elements()[i];
if (elem_abs->isa<abstract::AbstractNone>() && !support_fallback_runtime) {
continue;
}
ret[ref_idx] = BaseRefToPyData(value_list[ref_idx], elem_abs);
ref_idx++;
}
if (ref_idx != value_size) {
MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got "
<< ref_idx;
}
return ret;
}
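How the template above surfaces in Python, assuming the dispatch shown below (an AbstractList output selects py::list, other sequences py::tuple): a dynamic-length sequence converts every element with the shared element abstract rather than per-element abstracts. A sketch based on the dynamic-length tests added in this PR:

from mindspore import jit, context
from mindspore.common import mutable

context.set_context(mode=context.GRAPH_MODE)

@jit
def foo():
    # mutable(x, True) marks the list as dynamic-length, so one element
    # abstract describes all elements.
    return mutable([1, 2, 3], True)

out = foo()
assert isinstance(out, list) and out == [1, 2, 3]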
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &abs) {
py::object ret;
@@ -516,55 +559,16 @@ py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr
}
// Current VectorRef reflects a COOTensor type
MS_LOG(DEBUG) << "abs: " << abs->ToString();
if (abs->isa<abstract::AbstractCSRTensor>()) {
return MakeCSRTensor(value_list);
}
if (abs->isa<abstract::AbstractCOOTensor>()) {
return MakeCOOTensor(value_list);
}
if (!abs->isa<abstract::AbstractSequence>()) {
return VectorRefToPyData(value_list, nullptr);
}
if (abs->isa<abstract::AbstractList>()) {
return AbstractSequenceToPyData<py::list>(value_list, abs);
}
auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
MS_EXCEPTION_IF_NULL(seq_abs);
bool dynamic_len = seq_abs->dynamic_len();
auto dynamic_len_element_abs = seq_abs->dynamic_len_element_abs();
if (dynamic_len || dynamic_len_element_abs != nullptr) {
if (dynamic_len_element_abs == nullptr) {
MS_LOG(INFO) << "Dynamic length sequence with no specified element abstract convert to empty tuple.";
for (size_t i = 0; i < value_size; i++) {
ref_tuple[i] = BaseRefToPyData(value_list[i]);
}
return ref_tuple;
}
if (dynamic_len_element_abs->isa<abstract::AbstractNone>()) {
MS_LOG(INFO) << "Dynamic length sequence with element None convert to empty tuple.";
return ref_tuple;
}
for (size_t i = 0; i < value_size; ++i) {
ref_tuple[i] = BaseRefToPyData(value_list[i], dynamic_len_element_abs);
}
return ref_tuple;
}
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
// If FALLBACK_RUNTIME is not enabled, the size of seq_abs may be larger than
// the size of value_list, because the backend will eliminate None.
size_t ref_idx = 0;
for (size_t i = 0; i < seq_abs->size(); i++) {
auto elem_abs = seq_abs->elements()[i];
if (elem_abs->isa<abstract::AbstractNone>() && !support_fallback_runtime) {
continue;
}
ref_tuple[ref_idx] = BaseRefToPyData(value_list[ref_idx], elem_abs);
ref_idx++;
}
if (ref_idx != value_size) {
MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got "
<< ref_idx;
}
ret = ref_tuple;
return ret;
return AbstractSequenceToPyData<py::tuple>(value_list, abs);
}
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,

View File

@@ -0,0 +1,558 @@
# Copyright 2023 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test return list type object from graph"""
import os
import pytest
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import ops
from mindspore.common import mutable
from mindspore import Tensor, jit, context
context.set_context(mode=context.GRAPH_MODE)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_constant_list():
"""
Feature: Return list in graph
Description: Support return constant list.
Expectation: No exception.
"""
@jit
def foo():
return [1, 2, 3, 4]
res = foo()
assert res == [1, 2, 3, 4]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_constant_list_2():
"""
Feature: Return list in graph
Description: Support return constant list.
Expectation: No exception.
"""
@jit
def foo():
return ["a", "b", "c", "d"]
res = foo()
assert res == ["a", "b", "c", "d"]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_constant_list_3():
"""
Feature: Return list in graph
Description: Support return constant list.
Expectation: No exception.
"""
@jit
def foo():
return [True, False, False, True]
res = foo()
assert res == [True, False, False, True]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_constant_list_4():
"""
Feature: Return list in graph
Description: Support return constant list.
Expectation: No exception.
"""
@jit
def foo():
return [Tensor([1]), Tensor([1, 2, 3]), Tensor([2, 3])]
res = foo()
assert len(res) == 3
assert np.all(res[0].asnumpy() == np.array([1]))
assert np.all(res[1].asnumpy() == np.array([1, 2, 3]))
assert np.all(res[2].asnumpy() == np.array([2, 3]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_constant_list_5():
"""
Feature: Return list in graph
Description: Support return constant list.
Expectation: No exception.
"""
@jit
def foo():
return [None, None, None]
res = foo()
assert res == [None, None, None]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_constant_list_6():
"""
Feature: Return list in graph
Description: Support return constant list.
Expectation: No exception.
"""
@jit
def foo():
return [np.array([1, 2, 3]), np.array([4, 5, 6]), 1]
res = foo()
assert isinstance(res, list)
assert len(res) == 3
assert np.all(res[0] == np.array([1, 2, 3]))
assert np.all(res[1] == np.array([4, 5, 6]))
assert res[2] == 1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_constant_list_7():
"""
Feature: Return list in graph
Description: Support return constant list.
Expectation: No exception.
"""
@jit
def foo():
return [1, "a", True, None, Tensor([2])]
res = foo()
assert res == [1, "a", True, None, Tensor([2])]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_make_list_node():
"""
Feature: Return list in graph
Description: Support return make list node.
Expectation: No exception.
"""
os.environ["GRAPH_OP_RUN"] = "1"
@jit
def foo(x):
return [x, x+1, x+2, 1]
res = foo(mutable(1))
assert res == [1, 2, 3, 1]
os.environ["GRAPH_OP_RUN"] = "0"
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_make_list_node_2():
"""
Feature: Return list in graph
Description: Support return make list node.
Expectation: No exception.
"""
@jit
def foo(x):
return [x, x+1, x+2, Tensor([4])]
res = foo(Tensor([1]))
assert res == [Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_make_list_node_3():
"""
Feature: Return list in graph
Description: Support return make list node.
Expectation: No exception.
"""
@jit
def foo(x):
return [x, mutable(1), "a"]
res = foo(Tensor([1]))
assert res == [Tensor([1]), 1, "a"]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_list_with_nest():
"""
Feature: Return list in graph
Description: Support return make list in nest scene.
Expectation: No exception.
"""
@jit
def foo():
return [[1, 2, 3], [4, 5, 6]]
res = foo()
assert res == [[1, 2, 3], [4, 5, 6]]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_list_with_nest_2():
"""
Feature: Return list in graph
Description: Support return make list in nest scene.
Expectation: No exception.
"""
@jit
def foo():
return [([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6]]
res = foo()
assert res == [([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6]]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_list_with_nest_3():
"""
Feature: Return list in graph
Description: Support return make list in nest scene.
Expectation: No exception.
"""
@jit
def foo():
return (([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6])
res = foo()
assert res == (([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6])
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_make_list_with_nest():
"""
Feature: Return list in graph
Description: Support return make list in nest scene.
Expectation: No exception.
"""
@jit
def foo(x):
return [[x, x], (x+1, x+2)]
res = foo(Tensor([0]))
assert res == [[Tensor([0]), Tensor([0])], (Tensor([1]), Tensor([2]))]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_make_list_with_nest_2():
"""
Feature: Return list in graph
Description: Support return make list in nest scene.
Expectation: No exception.
"""
@jit
def foo(x):
return [x, ([x, 1],)], (x+1, x+2)
res = foo(Tensor([0]))
assert res == ([Tensor([0]), ([Tensor([0]), 1],)], (Tensor([1]), Tensor([2])))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_buildin_list_func():
"""
Feature: Return list in graph
Description: Support return result of list() function.
Expectation: No exception.
"""
@jit
def foo():
return list((1, "2", None, Tensor([1])))
res = foo()
assert res == [1, "2", None, Tensor([1])]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_buildin_list_func_2():
"""
Feature: Return list in graph
Description: Support return result of list() function.
Expectation: No exception.
"""
@jit
def foo(x):
return list(x)
res = foo(Tensor([1, 2, 3]))
assert res == [Tensor([1]), Tensor([2]), Tensor([3])]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_dynamic_length_list():
"""
Feature: Return list in graph
Description: Support return dynamic length list.
Expectation: No exception.
"""
@jit
def foo():
x = mutable([1, 2, 3], True)
return x
res = foo()
assert res == [1, 2, 3]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_dynamic_length_list_2():
"""
Feature: Return list in graph
Description: Support return dynamic length list.
Expectation: No exception.
"""
@jit
def foo(m):
x = mutable([m, m+1], True)
return x
res = foo(Tensor([0]))
assert res == [Tensor([0]), Tensor([1])]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_list_from_third_party():
"""
Feature: Return list in graph
Description: Support return list from third party.
Expectation: No exception.
"""
@jit
def foo():
m = np.array([1, 2, 3, 4])
x = m.tolist()
return x
res = foo()
assert res == [1, 2, 3, 4]
@pytest.mark.skip(reason="Getattr for interpret node failed.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_list_from_third_party_2():
"""
Feature: Return list in graph
Description: Support return list from third party.
Expectation: No exception.
"""
@jit
def foo(m):
x = m.asnumpy().tolist()
return x
res = foo(Tensor([1, 2, 3, 4]))
assert res == [1, 2, 3, 4]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_list_from_third_party_3():
"""
Feature: Return list in graph
Description: Support return list from third party.
Expectation: No exception.
"""
@jit
def foo():
x = np.arange(0, 10, 2)
return list(x)
res = foo()
assert res == [0, 2, 4, 6, 8]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_list_from_dict_attribute():
"""
Feature: Return list in graph
Description: Support return list from dict keys and values.
Expectation: No exception.
"""
@jit
def foo(x, y):
m = {"1": x, "2": y}
return list(m.keys()), list(m.values())
res = foo(Tensor([1]), mutable(2))
assert len(res) == 2
assert res[0] == ["1", "2"]
assert res[1] == [Tensor([1]), 2]
@pytest.mark.skip(reason="Return dict change the abstract.")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_return_list_from_dict_attribute_2():
"""
Feature: Return list in graph
Description: Support return list from dict keys and values.
Expectation: No exception.
"""
@jit
def foo(x, y):
m = {"1": x, "2": y}
return m, list(m.keys()), list(m.values())
res = foo(Tensor([1]), mutable(2))
assert len(res) == 3
assert res[1] == ["1", "2"]
assert res[2] == [Tensor([1]), 2]
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_for_return_list_graph():
"""
Feature: Return list in graph
Description: Support calculate gradient for graph with list return.
Expectation: No exception.
"""
@jit
def foo(x):
y = ops.ReLU()(x)
return [y,]
x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
res = ops.grad(foo)(x)
assert np.allclose(res.asnumpy(), np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).astype(np.float32))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_grad_for_graph_with_list_input():
"""
Feature: Return list in graph
Description: Support calculate gradient for graph with list return.
Expectation: No exception.
"""
@jit
def foo(t):
x = t[0]
y = t[1]
out = ops.MatMul()(x, y)
return out
t = mutable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = ops.grad(foo)(t)
assert isinstance(output, list)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert np.allclose(output[0].asnumpy(), expect[0])
assert np.allclose(output[1].asnumpy(), expect[1])

View File

@@ -37,7 +37,7 @@ def test_fallback_list_with_input_constant_tensor():
x.append(Tensor([4]))
return x
out = foo()
assert isinstance(out, tuple)
assert isinstance(out, list)
assert len(out) == 4
assert isinstance(out[0], Tensor)
assert out[0].asnumpy() == 1
@@ -66,7 +66,7 @@ def test_fallback_list_with_input_constant_tensor_2():
x.append(Tensor([5, 6]))
return x
out = foo()
assert isinstance(out, tuple)
assert isinstance(out, list)
assert len(out) == 3
assert isinstance(out[0], Tensor)
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
@@ -139,7 +139,7 @@ def test_fallback_tuple_with_input_constant_tensor_2():
x = list(Tensor([[1, 2], [3, 4]]))
return x
out = foo()
assert isinstance(out, tuple)
assert isinstance(out, list)
assert len(out) == 2
assert isinstance(out[0], Tensor)
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))

View File

@@ -269,6 +269,7 @@ def test_raise_with_variable_control_flow3():
raise_info_joinedstr_tensor.value)
@pytest.mark.skip(reason="not support yet")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard

View File

@@ -119,7 +119,7 @@ def test_grad_mutable_list_tensor():
t = mutable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
assert isinstance(output, list)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
@@ -310,7 +310,7 @@ def test_grad_mutable_list_tuple_tensor():
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
assert isinstance(output, list)
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
[0, 0, 0]]).astype(np.float32)],
@@ -423,6 +423,7 @@ def test_grad_mutable_dict_tuple_tensor():
assert compare(output, expect)
@pytest.mark.skip(reason="Do not support yet")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -467,7 +468,7 @@ def test_grad_mutable_list_dict_tensor():
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert isinstance(output, tuple)
assert isinstance(output, list)
assert len(output) == 2
assert isinstance(output[0], dict)
assert len(output[0].keys()) == 2
@@ -579,7 +580,7 @@ def test_grad_mutable_list_tensor_jit_function():
context.set_context(mode=context.GRAPH_MODE)
output = GradOperation()(net)(z)
assert isinstance(output, tuple)
assert isinstance(output, list)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
@@ -673,7 +674,7 @@ def test_grad_mutable_unused_list_tensor():
Tensor([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], dtype=mstype.float32),
Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=mstype.float32)])
output = GradNetWrtX(Net())(t)
assert isinstance(output, tuple)
assert isinstance(output, list)
expect = [np.array([[3., 3., 3.],
[3., 3., 3.]]).astype(np.float32),
np.array([[0., 0., 0.],

View File

@@ -279,7 +279,7 @@ def test_grad_const_list_tensor_to_mutable():
grad_net = GradNetWrtX(Net())
output = grad_net()
assert isinstance(output, tuple)
assert isinstance(output, list)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
@@ -288,7 +288,7 @@ def test_grad_const_list_tensor_to_mutable():
assert compare(output, expect)
grad_net = GradNetWrtX1(Net())
output = grad_net()
assert isinstance(output, tuple)
assert isinstance(output, list)
assert compare(output, expect)
@@ -338,7 +338,7 @@ def test_grad_const_tuple_or_list_tensor_arg_to_mutable():
x = [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)]
output = grad_net(x)
assert isinstance(output, tuple)
assert isinstance(output, list)
expect = [np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[1.7, 1.7, 1.7],
@@ -347,6 +347,7 @@ def test_grad_const_tuple_or_list_tensor_arg_to_mutable():
assert compare(output, expect)
@pytest.mark.skip(reason="not support yet.")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
@@ -397,7 +398,7 @@ def test_grad_const_list_and_tuple_tensor_to_mutable():
grad_net = GradNetWrtX(Net())
output = grad_net()
assert isinstance(output, tuple)
assert isinstance(output, list)
expect = [(np.array([[1.4100001, 1.5999999, 6.6],
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
np.array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
@@ -405,10 +406,10 @@ def test_grad_const_list_and_tuple_tensor_to_mutable():
np.array([[1.7, 1.7, 1.7],
[1.9, 1.9, 1.9],
[1.5, 1.5, 1.5]]).astype(np.float32)]
assert compare(output, expect)
assert compare(output, expect)
grad_net = GradNetWrtX1(Net())
output = grad_net()
assert isinstance(output, tuple)
assert isinstance(output, list)
assert compare(output, expect)

View File

@@ -53,7 +53,7 @@ def test_value_node_with_depend():
x = [[1, 2, 3, 4], [5, 6, 7, 8]]
net = NetValueNodeWithDepend()
output = net(x)
assert output == (5, 0, 7, 8)
assert output == [5, 0, 7, 8]
@pytest.mark.level0

View File

@@ -41,7 +41,7 @@ def test_single_if_no_else_type():
test_net = FalseNet()
res = test_net()
assert str(res) == "(<class 'numpy.ndarray'>, <class 'object'>)"
assert str(res) == "[<class 'numpy.ndarray'>, <class 'object'>]"
def test_single_if_no_else_type_2():
@@ -64,7 +64,7 @@ def test_single_if_no_else_type_2():
test_net = TrueNet()
res = test_net()
assert str(res) == "(<class 'int'>, <class 'object'>)"
assert str(res) == "[<class 'int'>, <class 'object'>]"
def test_single_if_1():

View File

@@ -61,7 +61,7 @@ def test_dict1():
input_me = Tensor(input_np)
net = Net1()
out_me = net(input_me)
assert out_me == ('x', 'y', 0, 1)
assert out_me == ['x', 'y', 0, 1]
def test_dict2():
@@ -76,7 +76,7 @@ def test_dict3():
input_me = Tensor(input_np)
net = Net3()
out_me = net(input_me)
assert out_me == ('x', 'y', 0, (0, 1))
assert out_me == ['x', 'y', 0, (0, 1)]
def test_dict4():

View File

@@ -576,7 +576,7 @@ def test_list_double_slice():
class NetInner(Cell):
def construct(self, a, b, start1, stop1, step1, start2, stop2, step2):
a[start1:stop1:step1][start2: stop2: step2] = b
return tuple(a)
return a
net = NetInner()
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
@@ -728,7 +728,7 @@ def test_list_slice_negetive_step():
a = [1, 2, 3, 4, 5]
b = [11, 22, 33, 44, 55]
a[-1:-4:-1] = b[-1:-4:-1]
return tuple(a)
return a
x = py_func()
y = ms_func()
@@ -771,6 +771,6 @@ def test_list_slice_only_with_step():
a = [1, 2, 3, 4]
b = [11, 22]
a[::2] = b
return tuple(a)
return a
assert ms_func() == py_func()

View File

@@ -33,7 +33,7 @@ def test_list_extend_1():
x.extend(y)
return x
out = list_net_1()
assert np.all(out == (1, 2, 3, 4, 5, 6, 7))
assert np.all(out == [1, 2, 3, 4, 5, 6, 7])
def test_list_extend_2():
@@ -52,7 +52,7 @@ def test_list_extend_2():
x.extend(z)
return x
out = list_net_2()
assert np.all(out == (1, 2, 3, 4, ('bb', '2', 3), 20))
assert np.all(out == [1, 2, 3, 4, ('bb', '2', 3), 20])
def test_list_extend_3():
@@ -73,8 +73,8 @@ def test_list_extend_3():
x.extend(z)
return x
out = list_net_3()
assert np.all(out == (1, 2, 3, 4, ('bb', '2', 3), 'Bob', 'a', ('Michael', 'Bob', '2'), \
20, 4, Tensor(1), (1, 2), Tensor(1)))
assert np.all(out == [1, 2, 3, 4, ('bb', '2', 3), 'Bob', 'a', ('Michael', 'Bob', '2'), \
20, 4, Tensor(1), (1, 2), Tensor(1)])
def test_list_extend_4():
@@ -90,7 +90,7 @@ def test_list_extend_4():
x.extend(y)
return x
out = list_net_4()
assert np.all(out == ())
assert np.all(out == [])
def test_list_extend_tuple():
@@ -107,4 +107,4 @@ def test_list_extend_tuple():
return x
out = func()
assert np.all(out == (1, 2, 3, 4, 5, 6, 7))
assert np.all(out == [1, 2, 3, 4, 5, 6, 7])

View File

@@ -132,7 +132,7 @@ def test_list_insert_pop_2():
return x, y
res_x, res_y = list_insert_pop(-2)
assert np.all(res_x == (3, 1, 4))
assert np.all(res_x == [3, 1, 4])
assert res_y == 3

View File

@@ -40,9 +40,9 @@ def test_list_mul_number():
"""
net = Net()
expect_ret0 = [Tensor([1, 2, 3])] * 5
expect_ret1 = (Tensor([1, 2, 3]),) * 0
assert isinstance(net()[0], tuple)
assert isinstance(net()[1], tuple)
expect_ret1 = [Tensor([1, 2, 3]),] * 0
assert isinstance(net()[0], list)
assert isinstance(net()[1], list)
for i in range(len(net()[0])):
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
assert net()[1] == expect_ret1

View File

@@ -33,7 +33,7 @@ def test_list_pop_1():
return x, y
res_x, res_y = list_pop()
assert np.all(res_x == (1, 2, 3))
assert np.all(res_x == [1, 2, 3])
assert res_y == 4
@@ -50,7 +50,7 @@ def test_list_pop_2():
return x, y
res_x, res_y = list_pop()
assert np.all(res_x == (1, 2, 4))
assert np.all(res_x == [1, 2, 4])
assert res_y == 3
@@ -67,7 +67,7 @@ def test_list_pop_3():
return x, y
res_x, res_y = list_pop()
assert np.all(res_x == (1, 3, 4))
assert np.all(res_x == [1, 3, 4])
assert res_y == 2
@@ -140,8 +140,8 @@ def test_list_pop_7():
return x1, x2, y1 + y2
res_x1, res_x2, res_y = list_pop()
assert np.all(res_x1 == (1, 3, 4))
assert np.all(res_x2 == (5, 6, 8))
assert np.all(res_x1 == [1, 3, 4])
assert np.all(res_x2 == [5, 6, 8])
assert res_y == 9
@@ -158,7 +158,7 @@ def test_list_pop_8():
return x, y
res_x, res_y = list_pop(2)
assert res_x == (Tensor([1]), Tensor([2]))
assert res_x == [Tensor([1]), Tensor([2])]
assert res_y == Tensor([3])
@@ -175,7 +175,7 @@ def test_list_pop_9():
input_x = [Tensor([1]), Tensor([2]), Tensor([3])]
res_x, res_y = list_pop(input_x, 2)
assert res_x == (Tensor([1]), Tensor([2]))
assert res_x == [Tensor([1]), Tensor([2])]
assert res_y == Tensor([3])

View File

@@ -32,7 +32,7 @@ def test_list_reverse_1():
x.reverse()
return x
out = list_net_1()
assert np.all(out == (4, 3, 2, 1))
assert np.all(out == [4, 3, 2, 1])
def test_list_reverse_2():
@@ -48,7 +48,7 @@ def test_list_reverse_2():
x.reverse()
return x
out = list_net_2()
assert np.all(out == (4, 20, ('bb', '2', 3), 'a'))
assert np.all(out == [4, 20, ('bb', '2', 3), 'a'])
def test_list_reverse_3():
@@ -65,7 +65,7 @@ def test_list_reverse_3():
x.reverse()
return x
out = list_net_3()
assert np.all(out == (Tensor(1), (1, 2), Tensor(1), 4, 20, ('Michael', 'Bob', '2'), 'a'))
assert np.all(out == [Tensor(1), (1, 2), Tensor(1), 4, 20, ('Michael', 'Bob', '2'), 'a'])
def test_list_reverse_4():
@@ -80,4 +80,4 @@ def test_list_reverse_4():
x.reverse()
return x
out = list_net_4()
assert np.all(out == ())
assert np.all(out == [])

View File

@@ -39,9 +39,9 @@ def test_number_mul_list():
context.set_context(mode=context.GRAPH_MODE)
net = Net()
expect_ret0 = 5 * [Tensor([1, 2, 3])]
expect_ret1 = 0 * (Tensor([1, 2, 3]),)
assert isinstance(net()[0], tuple)
assert isinstance(net()[1], tuple)
expect_ret1 = 0 * [Tensor([1, 2, 3]),]
assert isinstance(net()[0], list)
assert isinstance(net()[1], list)
for i in range(len(net()[0])):
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
assert net()[1] == expect_ret1

View File

@@ -30,7 +30,7 @@ def test_fallback_dict_empty():
dict_x['a'] = [1, 2, 3]
return dict_x["a"]
assert foo() == (1, 2, 3)
assert foo() == [1, 2, 3]
def test_fallback_dict_zip_iter_assign():

View File

@@ -137,7 +137,7 @@ def test_fallback_reversed():
def foo():
x = reversed([1, 2, 3])
return list(x)
assert foo() == (3, 2, 1)
assert foo() == [3, 2, 1]
def test_fallback_set():

View File

@@ -35,8 +35,8 @@ def test_fallback_list_with_input_tuple():
x.append(4)
return x
out = foo()
assert isinstance(out, tuple)
assert operator.eq(out, (1, 2, 3, 4))
assert isinstance(out, list)
assert operator.eq(out, [1, 2, 3, 4])
def test_fallback_list_with_input_list():
@@ -51,8 +51,8 @@ def test_fallback_list_with_input_list():
x.append(4)
return x
out = foo()
assert isinstance(out, tuple)
assert operator.eq(out, (1, 2, 3, 4))
assert isinstance(out, list)
assert operator.eq(out, [1, 2, 3, 4])
def test_fallback_list_with_input_dict():
@@ -67,8 +67,8 @@ def test_fallback_list_with_input_dict():
x.append('d')
return x
out = foo()
assert isinstance(out, tuple)
assert operator.eq(out, ('a', 'b', 'c', 'd'))
assert isinstance(out, list)
assert operator.eq(out, ['a', 'b', 'c', 'd'])
def test_fallback_list_with_input_numpy_array():

View File

@@ -131,7 +131,7 @@ def test_fallback_max_with_two_inputs_list():
x = max([1, 2, 3], [4, 5])
return x
out = foo()
assert operator.eq(out, (4, 5))
assert operator.eq(out, [4, 5])
def test_fallback_min_with_two_inputs_list():
@@ -145,7 +145,7 @@ def test_fallback_min_with_two_inputs_list():
x = min([1, 2, 3], [4, 5])
return x
out = foo()
assert operator.eq(out, (1, 2, 3))
assert operator.eq(out, [1, 2, 3])
def test_builtin_function_max_min_with_string():

View File

@@ -53,21 +53,21 @@ def get_list_comp_5():
@jit
def get_generator_exp_1():
t = (x for x in range(1, 6))
return t
return tuple(t)
@jit
def get_generator_exp_2():
t = (x * x for x in range(1, 11) if x > 5 if x % 2 == 0)
return t
return tuple(t)
def test_list_comp():
context.set_context(mode=context.GRAPH_MODE)
assert get_list_comp_1() == (1, 2, 3, 4, 5)
assert get_list_comp_2() == (1, 4, 9, 16, 25)
assert get_list_comp_3() == (4, 16, 36, 64, 100)
assert get_list_comp_4() == (36, 64, 100)
assert get_list_comp_1() == [1, 2, 3, 4, 5]
assert get_list_comp_2() == [1, 4, 9, 16, 25]
assert get_list_comp_3() == [4, 16, 36, 64, 100]
assert get_list_comp_4() == [36, 64, 100]
with pytest.raises(TypeError) as ex:
get_list_comp_5()
assert "The 'generators' supports 1 'comprehension' in ListComp/GeneratorExp" in str(ex.value)

View File

@@ -41,4 +41,4 @@ class NetWork(Cell):
list1 = [1, 2, 3]
net1 = NetWork()
result = net1(list1)
assert result == (1, 3)
assert result == [1, 3]

View File

@@ -58,7 +58,7 @@ def test_hypermap_value():
return map(lambda a: a + 1, self._list)
net = Net()
assert net() == (23, 67, 89, 112)
assert net() == [23, 67, 89, 112]
def test_hypermap_func_const():
@@ -80,4 +80,4 @@ def test_hypermap_func_const():
return map(lambda f: f(4), _list)
net = NetMap()
assert net() == (8, 12, 16)
assert net() == [8, 12, 16]

View File

@@ -66,7 +66,7 @@ def test_output_const_list():
return ret
net = Net()
assert net() == (1, 2, 3)
assert net() == [1, 2, 3]
def test_output_const_int():