From c30a1c70180c7cd94c01289e2a66199f1a55133a Mon Sep 17 00:00:00 2001 From: liangzhibo Date: Thu, 16 Feb 2023 19:45:42 +0800 Subject: [PATCH] Enable list as output --- .../common/pass/convert_list_to_tuple.cc | 15 + .../frontend/optimizer/fallback_rewriter.cc | 141 +---- mindspore/ccsrc/utils/convert_utils_py.cc | 88 +-- .../test_graph_fallback_return_list.py | 558 ++++++++++++++++++ .../test_list_tuple.py | 6 +- .../test_graph_raise_with_variable.py | 1 + tests/st/mutable/test_grad_mutable.py | 11 +- tests/st/mutable/test_mutable_in_graph.py | 13 +- tests/st/runtime/test_runtime_output.py | 2 +- .../test_fallback_000_single_if.py | 4 +- .../graph_syntax/dict/test_dictionary.py | 4 +- .../graph_syntax/list/test_list_assign.py | 6 +- .../graph_syntax/list/test_list_extend.py | 12 +- .../graph_syntax/list/test_list_insert.py | 2 +- .../graph_syntax/list/test_list_mul_number.py | 6 +- .../python/graph_syntax/list/test_list_pop.py | 14 +- .../graph_syntax/list/test_list_reverse.py | 8 +- .../graph_syntax/list/test_number_mul_list.py | 6 +- .../python_builtin_functions/test_dict.py | 2 +- .../test_graph_fallback_python_builtin.py | 2 +- .../test_list_tuple.py | 12 +- .../python_builtin_functions/test_max_min.py | 4 +- .../graph_syntax/statements/test_list_comp.py | 12 +- tests/ut/python/ops/test_filter.py | 2 +- .../ut/python/pipeline/infer/test_hypermap.py | 4 +- .../pipeline/parse/test_structure_output.py | 2 +- 26 files changed, 695 insertions(+), 242 deletions(-) create mode 100644 tests/st/fallback/test_graph_fallback_return_list.py diff --git a/mindspore/ccsrc/backend/common/pass/convert_list_to_tuple.cc b/mindspore/ccsrc/backend/common/pass/convert_list_to_tuple.cc index cd1172302f5..ee145970f31 100644 --- a/mindspore/ccsrc/backend/common/pass/convert_list_to_tuple.cc +++ b/mindspore/ccsrc/backend/common/pass/convert_list_to_tuple.cc @@ -40,6 +40,21 @@ bool ConvertListToTuple::Run(const FuncGraphPtr &graph) { << ",debug name:" << node->DebugString(); } } + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (auto node : graph->parameters()) { + MS_EXCEPTION_IF_NULL(node); + // Convert unused list parameter to tuple. + if (manager->node_users().find(node) != manager->node_users().end()) { + continue; + } + auto new_abs = ConvertSequenceAbsToTupleAbs(node->abstract()); + if (new_abs != nullptr) { + node->set_abstract(new_abs); + MS_LOG(INFO) << "Convert sequence abstract to tuple abstract for op:" << node->fullname_with_scope() + << ",debug name:" << node->DebugString(); + } + } return true; } diff --git a/mindspore/ccsrc/frontend/optimizer/fallback_rewriter.cc b/mindspore/ccsrc/frontend/optimizer/fallback_rewriter.cc index 540a64adf99..172a41f14d7 100644 --- a/mindspore/ccsrc/frontend/optimizer/fallback_rewriter.cc +++ b/mindspore/ccsrc/frontend/optimizer/fallback_rewriter.cc @@ -613,63 +613,6 @@ class CleanAfterOptARewriter : public BaseRewriter { ~CleanAfterOptARewriter() override = default; protected: - // From: - // MakeList(arg1, arg2, ...) - // To: - // MakeTuple(arg1, arg2, ...) - AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - std::vector inputs; - inputs.reserve(node->size()); - (void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); - // Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items; - (void)inputs.insert(inputs.cend(), node->inputs().cbegin() + 1, node->inputs().cend()); - return node->func_graph()->NewCNode(std::move(inputs)); - } - - // From: - // list_getitem(list, key) - // To: - // TupleGetItem(list, key) - AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - // Inputs should be [list_getitem, list, item] - constexpr size_t expect_input_size = 3; - CheckInputsSize(node, expect_input_size); - constexpr size_t data_index = 1; - constexpr size_t cons_index = 2; - const auto &inputs = node->inputs(); - auto &data = inputs[data_index]; - auto &key = inputs[cons_index]; - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, key}); - } - - // From: - // ListSetItem(list, index, item) - // To: - // TupleSetItem(list, index, item) - AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(node->func_graph()); - - // Inputs should be [list_setitem, list, index, item] - const size_t expect_inputs_size = 4; - CheckInputsSize(node, expect_inputs_size); - - const size_t data_index = 1; - const size_t cons_index = 2; - const size_t value_index = 3; - const auto &inputs = node->inputs(); - auto &data = inputs[data_index]; - auto &key = inputs[cons_index]; - auto &value = inputs[value_index]; - return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, key, value}); - } - // From: // MakeSparseTensor(indices, values, dense_shape) // To: @@ -922,9 +865,6 @@ class CleanAfterOptARewriter : public BaseRewriter { using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &); using ConverterMap = mindspore::HashMap; static inline const ConverterMap converters_{ - {prim::kPrimMakeList, &ThisClass::ConvertMakeListToMakeTuple}, - {prim::kPrimListGetItem, &ThisClass::ConvertListGetItemToTupleGetItem}, - {prim::kPrimListSetItem, &ThisClass::ConvertListSetItemToTupleSetItem}, // SparseProcess: 1.MakeSparse->MakeTuple 2.SparseGetAttr->TupleGetItem {prim::kPrimMakeRowTensor, &ThisClass::ConvertMakeSparseToMakeTuple}, {prim::kPrimRowTensorGetIndices, &ThisClass::ConvertSparseGetAttrToTupleGetItem}, @@ -994,39 +934,18 @@ class CleanAfterOptARewriter : public BaseRewriter { return (this->*(iter->second))(cnode); } - static ValuePtr ConvertValueSequenceToValueTuple(const ValuePtr &value, size_t depth, bool *need_convert) { - MS_EXCEPTION_IF_NULL(need_convert); - MS_EXCEPTION_IF_NULL(value); - if (depth > kMaxSeqRecursiveDepth) { - MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels."; - } - - if (value->isa()) { - std::vector elements; - auto value_seq = value->cast(); - (void)std::transform(value_seq->value().begin(), value_seq->value().end(), std::back_inserter(elements), - [&](const ValuePtr &value) -> ValuePtr { - bool is_convert = false; - auto convert_value = ConvertValueSequenceToValueTuple(value, depth + 1, &is_convert); - *need_convert |= is_convert; - return convert_value; - }); - *need_convert |= value->isa(); - if (*need_convert) { - return std::make_shared(elements); - } - } - - return value; - } - AnfNodePtr ProcessValueSequence(const ValuePtr &value) { MS_EXCEPTION_IF_NULL(value); if (value->isa()) { auto value_seq = value->cast(); MS_EXCEPTION_IF_NULL(value_seq); auto values = value_seq->value(); - std::vector value_seq_inputs{NewValueNode(prim::kPrimMakeTuple)}; + std::vector value_seq_inputs; + if (value_seq->isa()) { + (void)value_seq_inputs.emplace_back(NewValueNode(prim::kPrimMakeList)); + } else { + (void)value_seq_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + } for (auto inner_value : values) { auto inner_value_seq = ProcessValueSequence(inner_value); (void)value_seq_inputs.emplace_back(inner_value_seq); @@ -1091,60 +1010,14 @@ class CleanAfterOptARewriter : public BaseRewriter { return ConvertInterpretedObjectValue(value_node, value->cast()); } } - bool need_convert = false; - auto convert_value = ConvertValueSequenceToValueTuple(value, 0, &need_convert); - if (need_convert) { - return std::make_shared(convert_value); - } return nullptr; } - // AbstractSequence, AbstractRowTensor --> AbstractTuple. + // AbstractRowTensor --> AbstractTuple. static AbstractBasePtr ConvertToAbstractTuple(const AbstractBasePtr &abs, size_t depth) { if (depth > kMaxSeqRecursiveDepth) { MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels."; } - // AbstractList --> AbstractTuple. - auto abs_seq = abs->cast(); - if (abs_seq != nullptr) { - // Dynamic length sequence do not convert. - if (abs_seq->dynamic_len() && abs_seq->isa()) { - auto converted_abs_tuple = std::make_shared(abs_seq->elements(), abs_seq->sequence_nodes()); - converted_abs_tuple->set_dynamic_len(true); - converted_abs_tuple->set_dynamic_len_element_abs(abs_seq->dynamic_len_element_abs()); - return converted_abs_tuple; - } - const auto &seq_elements = abs_seq->elements(); - // First we check if elements should be converted, - // changed_elements maps old element to new element. - mindspore::HashMap changed_elements; - for (const auto &element : seq_elements) { - auto new_element = ConvertToAbstractTuple(element, depth + 1); - if (new_element != nullptr) { - (void)changed_elements.emplace(element, new_element); - } - } - if (changed_elements.empty()) { - if (abs->isa()) { - // If no elements changed and it is an AbstractTuple, do not convert. - return nullptr; - } - // If no elements changed but it is not an AbstractTuple, convert it by copy elements. - return std::make_shared(seq_elements); - } - // Always make new AbstractTuple when elements changed. - std::vector elements; - elements.reserve(seq_elements.size()); - for (const auto &element : seq_elements) { - auto iter = changed_elements.find(element); - if (iter != changed_elements.end()) { - (void)elements.emplace_back(iter->second); - } else { - (void)elements.emplace_back(element); - } - } - return std::make_shared(std::move(elements)); - } // AbstractRowTensor --> AbstractTuple. auto abs_row_tensor = abs->cast>(); if (abs_row_tensor != nullptr) { diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index f5840b32ded..200141e3680 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -502,6 +502,49 @@ py::object VectorToPyData(const Any &value) { } return ret; } +template +py::object AbstractSequenceToPyData(const VectorRef &value_list, const AbstractBasePtr &abs) { + auto value_size = value_list.size(); + auto ret = T(value_size); + auto seq_abs = abs->cast(); + MS_EXCEPTION_IF_NULL(seq_abs); + bool dynamic_len = seq_abs->dynamic_len(); + auto dynamic_len_element_abs = seq_abs->dynamic_len_element_abs(); + if (dynamic_len || dynamic_len_element_abs != nullptr) { + if (dynamic_len_element_abs == nullptr) { + MS_LOG(INFO) << "Dynamic length sequence with no specified element abstract convert to empty tuple."; + for (size_t i = 0; i < value_size; i++) { + ret[i] = BaseRefToPyData(value_list[i]); + } + return ret; + } + if (dynamic_len_element_abs->isa()) { + MS_LOG(INFO) << "Dynamic length sequence with element None convert to empty sequence."; + return ret; + } + for (size_t i = 0; i < value_size; ++i) { + ret[i] = BaseRefToPyData(value_list[i], dynamic_len_element_abs); + } + return ret; + } + static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0"); + // If FALLBACK_RUNTIME is not enable + // The size of seq_abs may be larger than the size of value_list, because the backend will eliminate None. + size_t ref_idx = 0; + for (size_t i = 0; i < seq_abs->size(); i++) { + auto elem_abs = seq_abs->elements()[i]; + if (elem_abs->isa() && !support_fallback_runtime) { + continue; + } + ret[ref_idx] = BaseRefToPyData(value_list[ref_idx], elem_abs); + ref_idx++; + } + if (ref_idx != value_size) { + MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got " + << ref_idx; + } + return ret; +} py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &abs) { py::object ret; @@ -516,55 +559,16 @@ py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr } // Current VectorRef reflects a COOTensor type - MS_LOG(DEBUG) << "abs: " << abs->ToString(); if (abs->isa()) { return MakeCSRTensor(value_list); } if (abs->isa()) { return MakeCOOTensor(value_list); } - if (!abs->isa()) { - return VectorRefToPyData(value_list, nullptr); + if (abs->isa()) { + return AbstractSequenceToPyData(value_list, abs); } - auto seq_abs = abs->cast(); - MS_EXCEPTION_IF_NULL(seq_abs); - bool dynamic_len = seq_abs->dynamic_len(); - auto dynamic_len_element_abs = seq_abs->dynamic_len_element_abs(); - if (dynamic_len || dynamic_len_element_abs != nullptr) { - if (dynamic_len_element_abs == nullptr) { - MS_LOG(INFO) << "Dynamic length sequence with no specified element abstract convert to empty tuple."; - for (size_t i = 0; i < value_size; i++) { - ref_tuple[i] = BaseRefToPyData(value_list[i]); - } - return ref_tuple; - } - if (dynamic_len_element_abs->isa()) { - MS_LOG(INFO) << "Dynamic length sequence with element None convert to empty tuple."; - return ref_tuple; - } - for (size_t i = 0; i < value_size; ++i) { - ref_tuple[i] = BaseRefToPyData(value_list[i], dynamic_len_element_abs); - } - return ref_tuple; - } - static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0"); - // If FALLBACK_RUNTIME is not enable - // The size of seq_abs may be larger than the size of value_list, because the backend will eliminate None. - size_t ref_idx = 0; - for (size_t i = 0; i < seq_abs->size(); i++) { - auto elem_abs = seq_abs->elements()[i]; - if (elem_abs->isa() && !support_fallback_runtime) { - continue; - } - ref_tuple[ref_idx] = BaseRefToPyData(value_list[ref_idx], elem_abs); - ref_idx++; - } - if (ref_idx != value_size) { - MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got " - << ref_idx; - } - ret = ref_tuple; - return ret; + return AbstractSequenceToPyData(value_list, abs); } bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args, diff --git a/tests/st/fallback/test_graph_fallback_return_list.py b/tests/st/fallback/test_graph_fallback_return_list.py new file mode 100644 index 00000000000..920c54a222e --- /dev/null +++ b/tests/st/fallback/test_graph_fallback_return_list.py @@ -0,0 +1,558 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test return list type object from graph""" +import os +import pytest +import numpy as np + +import mindspore.common.dtype as mstype +from mindspore import ops +from mindspore.common import mutable +from mindspore import Tensor, jit, context + +context.set_context(mode=context.GRAPH_MODE) + + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_constant_list(): + """ + Feature: Return list in graph + Description: Support return constant list. + Expectation: No exception. + """ + @jit + def foo(): + return [1, 2, 3, 4] + + res = foo() + assert res == [1, 2, 3, 4] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_constant_list_2(): + """ + Feature: Return list in graph + Description: Support return constant list. + Expectation: No exception. + """ + @jit + def foo(): + return ["a", "b", "c", "d"] + + res = foo() + assert res == ["a", "b", "c", "d"] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_constant_list_3(): + """ + Feature: Return list in graph + Description: Support return constant list. + Expectation: No exception. + """ + @jit + def foo(): + return [True, False, False, True] + + res = foo() + assert res == [True, False, False, True] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_constant_list_4(): + """ + Feature: Return list in graph + Description: Support return constant list. + Expectation: No exception. + """ + @jit + def foo(): + return [Tensor([1]), Tensor([1, 2, 3]), Tensor([2, 3])] + + res = foo() + assert len(res) == 3 + assert np.all(res[0].asnumpy() == np.array([1])) + assert np.all(res[1].asnumpy() == np.array([1, 2, 3])) + assert np.all(res[2].asnumpy() == np.array([2, 3])) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_constant_list_5(): + """ + Feature: Return list in graph + Description: Support return constant list. + Expectation: No exception. + """ + @jit + def foo(): + return [None, None, None] + + res = foo() + assert res == [None, None, None] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_constant_list_6(): + """ + Feature: Return list in graph + Description: Support return constant list. + Expectation: No exception. + """ + @jit + def foo(): + return [np.array([1, 2, 3]), np.array([4, 5, 6]), 1] + + res = foo() + assert isinstance(res, list) + assert len(res) == 3 + assert np.all(res[0] == np.array([1, 2, 3])) + assert np.all(res[1] == np.array([4, 5, 6])) + assert res[2] == 1 + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_constant_list_7(): + """ + Feature: Return list in graph + Description: Support return constant list. + Expectation: No exception. + """ + @jit + def foo(): + return [1, "a", True, None, Tensor([2])] + + res = foo() + assert res == [1, "a", True, None, Tensor([2])] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_make_list_node(): + """ + Feature: Return list in graph + Description: Support return make list node. + Expectation: No exception. + """ + os.environ["GRAPH_OP_RUN"] = "1" + @jit + def foo(x): + return [x, x+1, x+2, 1] + + res = foo(mutable(1)) + assert res == [1, 2, 3, 1] + os.environ["GRAPH_OP_RUN"] = "0" + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_make_list_node_2(): + """ + Feature: Return list in graph + Description: Support return make list node. + Expectation: No exception. + """ + @jit + def foo(x): + return [x, x+1, x+2, Tensor([4])] + + res = foo(Tensor([1])) + assert res == [Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_make_list_node_3(): + """ + Feature: Return list in graph + Description: Support return make list node. + Expectation: No exception. + """ + @jit + def foo(x): + return [x, mutable(1), "a"] + + res = foo(Tensor([1])) + assert res == [Tensor([1]), 1, "a"] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_list_with_nest(): + """ + Feature: Return list in graph + Description: Support return make list in nest scene. + Expectation: No exception. + """ + @jit + def foo(): + return [[1, 2, 3], [4, 5, 6]] + + res = foo() + assert res == [[1, 2, 3], [4, 5, 6]] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_list_with_nest_2(): + """ + Feature: Return list in graph + Description: Support return make list in nest scene. + Expectation: No exception. + """ + @jit + def foo(): + return [([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6]] + + res = foo() + assert res == [([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6]] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_list_with_nest_3(): + """ + Feature: Return list in graph + Description: Support return make list in nest scene. + Expectation: No exception. + """ + @jit + def foo(): + return (([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6]) + + res = foo() + assert res == (([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6]) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_make_list_with_nest(): + """ + Feature: Return list in graph + Description: Support return make list in nest scene. + Expectation: No exception. + """ + @jit + def foo(x): + return [[x, x], (x+1, x+2)] + + res = foo(Tensor([0])) + assert res == [[Tensor([0]), Tensor([0])], (Tensor([1]), Tensor([2]))] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_make_list_with_nest_2(): + """ + Feature: Return list in graph + Description: Support return make list in nest scene. + Expectation: No exception. + """ + @jit + def foo(x): + return [x, ([x, 1],)], (x+1, x+2) + + res = foo(Tensor([0])) + assert res == ([Tensor([0]), ([Tensor([0]), 1],)], (Tensor([1]), Tensor([2]))) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_buildin_list_func(): + """ + Feature: Return list in graph + Description: Support return result of list() function. + Expectation: No exception. + """ + @jit + def foo(): + return list((1, "2", None, Tensor([1]))) + + res = foo() + assert res == [1, "2", None, Tensor([1])] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_buildin_list_func_2(): + """ + Feature: Return list in graph + Description: Support return result of list() function. + Expectation: No exception. + """ + @jit + def foo(x): + return list(x) + + res = foo(Tensor([1, 2, 3])) + assert res == [Tensor([1]), Tensor([2]), Tensor([3])] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_dynamic_length_list(): + """ + Feature: Return list in graph + Description: Support return dynamic length list. + Expectation: No exception. + """ + @jit + def foo(): + x = mutable([1, 2, 3], True) + return x + + res = foo() + assert res == [1, 2, 3] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_dynamic_length_list_2(): + """ + Feature: Return list in graph + Description: Support return dynamic length list. + Expectation: No exception. + """ + @jit + def foo(m): + x = mutable([m, m+1], True) + return x + + res = foo(Tensor([0])) + assert res == [Tensor([0]), Tensor([1])] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_list_from_third_party(): + """ + Feature: Return list in graph + Description: Support return list from third party. + Expectation: No exception. + """ + @jit + def foo(): + m = np.array([1, 2, 3, 4]) + x = m.tolist() + return x + + res = foo() + assert res == [1, 2, 3, 4] + + +@pytest.mark.skip(reason="Getattr for interpret node failed.") +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_list_from_third_party_2(): + """ + Feature: Return list in graph + Description: Support return list from third party. + Expectation: No exception. + """ + @jit + def foo(m): + x = m.asnumpy().tolist() + return x + + res = foo(Tensor([1, 2, 3, 4])) + assert res == [1, 2, 3, 4] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_list_from_third_party_3(): + """ + Feature: Return list in graph + Description: Support return list from third party. + Expectation: No exception. + """ + @jit + def foo(): + x = np.arange(0, 10, 2) + return list(x) + + res = foo() + assert res == [0, 2, 4, 6, 8] + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_list_from_dict_attribute(): + """ + Feature: Return list in graph + Description: Support return list from dict keys and values. + Expectation: No exception. + """ + @jit + def foo(x, y): + m = {"1": x, "2": y} + return list(m.keys()), list(m.values()) + + res = foo(Tensor([1]), mutable(2)) + assert len(res) == 2 + assert res[0] == ["1", "2"] + assert res[1] == [Tensor([1]), 2] + + +@pytest.mark.skip(reason="Return dict change the abstract.") +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_return_list_from_dict_attribute_2(): + """ + Feature: Return list in graph + Description: Support return list from dict keys and values. + Expectation: No exception. + """ + @jit + def foo(x, y): + m = {"1": x, "2": y} + return m, list(m.keys()), list(m.values()) + + res = foo(Tensor([1]), mutable(2)) + assert len(res) == 3 + assert res[1] == ["1", "2"] + assert res[2] == [Tensor([1]), 2] + + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_for_return_list_graph(): + """ + Feature: Return list in graph + Description: Support calculate gradient for graph with list return. + Expectation: No exception. + """ + @jit + def foo(x): + y = ops.ReLU()(x) + return [y,] + + x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32) + res = ops.grad(foo)(x) + assert np.allclose(res.asnumpy(), np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).astype(np.float32)) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_for_graph_with_list_input(): + """ + Feature: Return list in graph + Description: Support calculate gradient for graph with list return. + Expectation: No exception. + """ + @jit + def foo(t): + x = t[0] + y = t[1] + out = ops.MatMul()(x, y) + return out + + t = mutable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), + Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)]) + output = ops.grad(foo)(t) + assert isinstance(output, list) + expect = [np.array([[1.4100001, 1.5999999, 6.6], + [1.4100001, 1.5999999, 6.6]]).astype(np.float32), + np.array([[1.7, 1.7, 1.7], + [1.9, 1.9, 1.9], + [1.5, 1.5, 1.5]]).astype(np.float32)] + assert np.allclose(output[0].asnumpy(), expect[0]) + assert np.allclose(output[1].asnumpy(), expect[1]) diff --git a/tests/st/graph_syntax/python_builtin_functions/test_list_tuple.py b/tests/st/graph_syntax/python_builtin_functions/test_list_tuple.py index 2c12a09525e..d323bac222c 100644 --- a/tests/st/graph_syntax/python_builtin_functions/test_list_tuple.py +++ b/tests/st/graph_syntax/python_builtin_functions/test_list_tuple.py @@ -37,7 +37,7 @@ def test_fallback_list_with_input_constant_tensor(): x.append(Tensor([4])) return x out = foo() - assert isinstance(out, tuple) + assert isinstance(out, list) assert len(out) == 4 assert isinstance(out[0], Tensor) assert out[0].asnumpy() == 1 @@ -66,7 +66,7 @@ def test_fallback_list_with_input_constant_tensor_2(): x.append(Tensor([5, 6])) return x out = foo() - assert isinstance(out, tuple) + assert isinstance(out, list) assert len(out) == 3 assert isinstance(out[0], Tensor) assert np.allclose(out[0].asnumpy(), np.array([1, 2])) @@ -139,7 +139,7 @@ def test_fallback_tuple_with_input_constant_tensor_2(): x = list(Tensor([[1, 2], [3, 4]])) return x out = foo() - assert isinstance(out, tuple) + assert isinstance(out, list) assert len(out) == 2 assert isinstance(out[0], Tensor) assert np.allclose(out[0].asnumpy(), np.array([1, 2])) diff --git a/tests/st/graph_syntax/statements/test_graph_raise_with_variable.py b/tests/st/graph_syntax/statements/test_graph_raise_with_variable.py index 69f3ef42759..93925bcb73d 100644 --- a/tests/st/graph_syntax/statements/test_graph_raise_with_variable.py +++ b/tests/st/graph_syntax/statements/test_graph_raise_with_variable.py @@ -269,6 +269,7 @@ def test_raise_with_variable_control_flow3(): raise_info_joinedstr_tensor.value) +@pytest.mark.skip(reason="not support yet") @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard diff --git a/tests/st/mutable/test_grad_mutable.py b/tests/st/mutable/test_grad_mutable.py index c4eebcc7359..28b53154ec6 100644 --- a/tests/st/mutable/test_grad_mutable.py +++ b/tests/st/mutable/test_grad_mutable.py @@ -119,7 +119,7 @@ def test_grad_mutable_list_tensor(): t = mutable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)]) output = GradNetWrtX(Net())(t) - assert isinstance(output, tuple) + assert isinstance(output, list) expect = [np.array([[1.4100001, 1.5999999, 6.6], [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[1.7, 1.7, 1.7], @@ -310,7 +310,7 @@ def test_grad_mutable_list_tuple_tensor(): Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)), Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)]) output = GradNetWrtX(Net())(t) - assert isinstance(output, tuple) + assert isinstance(output, list) expect = [[np.array([[1.4100001, 1.5999999, 6.6], [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0], [0, 0, 0]]).astype(np.float32)], @@ -423,6 +423,7 @@ def test_grad_mutable_dict_tuple_tensor(): assert compare(output, expect) +@pytest.mark.skip(reason="Do not support yet") @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -467,7 +468,7 @@ def test_grad_mutable_list_dict_tensor(): np.array([[1.7, 1.7, 1.7], [1.9, 1.9, 1.9], [1.5, 1.5, 1.5]]).astype(np.float32)] - assert isinstance(output, tuple) + assert isinstance(output, list) assert len(output) == 2 assert isinstance(output[0], dict) assert len(output[0].keys()) == 2 @@ -579,7 +580,7 @@ def test_grad_mutable_list_tensor_jit_function(): context.set_context(mode=context.GRAPH_MODE) output = GradOperation()(net)(z) - assert isinstance(output, tuple) + assert isinstance(output, list) expect = [np.array([[1.4100001, 1.5999999, 6.6], [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[1.7, 1.7, 1.7], @@ -673,7 +674,7 @@ def test_grad_mutable_unused_list_tensor(): Tensor([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], dtype=mstype.float32), Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=mstype.float32)]) output = GradNetWrtX(Net())(t) - assert isinstance(output, tuple) + assert isinstance(output, list) expect = [np.array([[3., 3., 3.], [3., 3., 3.]]).astype(np.float32), np.array([[0., 0., 0.], diff --git a/tests/st/mutable/test_mutable_in_graph.py b/tests/st/mutable/test_mutable_in_graph.py index e48bfd49299..d92787398c0 100644 --- a/tests/st/mutable/test_mutable_in_graph.py +++ b/tests/st/mutable/test_mutable_in_graph.py @@ -279,7 +279,7 @@ def test_grad_const_list_tensor_to_mutable(): grad_net = GradNetWrtX(Net()) output = grad_net() - assert isinstance(output, tuple) + assert isinstance(output, list) expect = [np.array([[1.4100001, 1.5999999, 6.6], [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[1.7, 1.7, 1.7], @@ -288,7 +288,7 @@ def test_grad_const_list_tensor_to_mutable(): assert compare(output, expect) grad_net = GradNetWrtX1(Net()) output = grad_net() - assert isinstance(output, tuple) + assert isinstance(output, list) assert compare(output, expect) @@ -338,7 +338,7 @@ def test_grad_const_tuple_or_list_tensor_arg_to_mutable(): x = [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32), Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)] output = grad_net(x) - assert isinstance(output, tuple) + assert isinstance(output, list) expect = [np.array([[1.4100001, 1.5999999, 6.6], [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[1.7, 1.7, 1.7], @@ -347,6 +347,7 @@ def test_grad_const_tuple_or_list_tensor_arg_to_mutable(): assert compare(output, expect) +@pytest.mark.skip(reason="not support yet.") @pytest.mark.level0 @pytest.mark.platform_x86_cpu @pytest.mark.env_onecard @@ -397,7 +398,7 @@ def test_grad_const_list_and_tuple_tensor_to_mutable(): grad_net = GradNetWrtX(Net()) output = grad_net() - assert isinstance(output, tuple) + assert isinstance(output, list) expect = [(np.array([[1.4100001, 1.5999999, 6.6], [1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00], @@ -405,10 +406,10 @@ def test_grad_const_list_and_tuple_tensor_to_mutable(): np.array([[1.7, 1.7, 1.7], [1.9, 1.9, 1.9], [1.5, 1.5, 1.5]]).astype(np.float32)] - assert compare(output, expect) + assert compare(output, list) grad_net = GradNetWrtX1(Net()) output = grad_net() - assert isinstance(output, tuple) + assert isinstance(output, list) assert compare(output, expect) diff --git a/tests/st/runtime/test_runtime_output.py b/tests/st/runtime/test_runtime_output.py index 9d5b72533aa..069c74ec5bb 100644 --- a/tests/st/runtime/test_runtime_output.py +++ b/tests/st/runtime/test_runtime_output.py @@ -53,7 +53,7 @@ def test_value_node_with_depend(): x = [[1, 2, 3, 4], [5, 6, 7, 8]] net = NetValueNodeWithDepend() output = net(x) - assert output == (5, 0, 7, 8) + assert output == [5, 0, 7, 8] @pytest.mark.level0 diff --git a/tests/ut/python/fallback/control_flow/test_fallback_000_single_if.py b/tests/ut/python/fallback/control_flow/test_fallback_000_single_if.py index c119b40a438..5c345464f7d 100644 --- a/tests/ut/python/fallback/control_flow/test_fallback_000_single_if.py +++ b/tests/ut/python/fallback/control_flow/test_fallback_000_single_if.py @@ -41,7 +41,7 @@ def test_single_if_no_else_type(): test_net = FalseNet() res = test_net() - assert str(res) == "(, )" + assert str(res) == "[, ]" def test_single_if_no_else_type_2(): @@ -64,7 +64,7 @@ def test_single_if_no_else_type_2(): test_net = TrueNet() res = test_net() - assert str(res) == "(, )" + assert str(res) == "[, ]" def test_single_if_1(): diff --git a/tests/ut/python/graph_syntax/dict/test_dictionary.py b/tests/ut/python/graph_syntax/dict/test_dictionary.py index f10c28b6dd3..cf362995f32 100644 --- a/tests/ut/python/graph_syntax/dict/test_dictionary.py +++ b/tests/ut/python/graph_syntax/dict/test_dictionary.py @@ -61,7 +61,7 @@ def test_dict1(): input_me = Tensor(input_np) net = Net1() out_me = net(input_me) - assert out_me == ('x', 'y', 0, 1) + assert out_me == ['x', 'y', 0, 1] def test_dict2(): @@ -76,7 +76,7 @@ def test_dict3(): input_me = Tensor(input_np) net = Net3() out_me = net(input_me) - assert out_me == ('x', 'y', 0, (0, 1)) + assert out_me == ['x', 'y', 0, (0, 1)] def test_dict4(): diff --git a/tests/ut/python/graph_syntax/list/test_list_assign.py b/tests/ut/python/graph_syntax/list/test_list_assign.py index d2e860dc9b9..1e96adc5a12 100644 --- a/tests/ut/python/graph_syntax/list/test_list_assign.py +++ b/tests/ut/python/graph_syntax/list/test_list_assign.py @@ -576,7 +576,7 @@ def test_list_double_slice(): class NetInner(Cell): def construct(self, a, b, start1, stop1, step1, start2, stop2, step2): a[start1:stop1:step1][start2: stop2: step2] = b - return tuple(a) + return a net = NetInner() a = [1, 2, 3, 4, 5, 6, 7, 8, 9] @@ -728,7 +728,7 @@ def test_list_slice_negetive_step(): a = [1, 2, 3, 4, 5] b = [11, 22, 33, 44, 55] a[-1:-4:-1] = b[-1:-4:-1] - return tuple(a) + return a x = py_func() y = ms_func() @@ -771,6 +771,6 @@ def test_list_slice_only_with_step(): a = [1, 2, 3, 4] b = [11, 22] a[::2] = b - return tuple(a) + return a assert ms_func() == py_func() diff --git a/tests/ut/python/graph_syntax/list/test_list_extend.py b/tests/ut/python/graph_syntax/list/test_list_extend.py index 13a6359c745..3ae43f7d2ba 100644 --- a/tests/ut/python/graph_syntax/list/test_list_extend.py +++ b/tests/ut/python/graph_syntax/list/test_list_extend.py @@ -33,7 +33,7 @@ def test_list_extend_1(): x.extend(y) return x out = list_net_1() - assert np.all(out == (1, 2, 3, 4, 5, 6, 7)) + assert np.all(out == [1, 2, 3, 4, 5, 6, 7]) def test_list_extend_2(): @@ -52,7 +52,7 @@ def test_list_extend_2(): x.extend(z) return x out = list_net_2() - assert np.all(out == (1, 2, 3, 4, ('bb', '2', 3), 20)) + assert np.all(out == [1, 2, 3, 4, ('bb', '2', 3), 20]) def test_list_extend_3(): @@ -73,8 +73,8 @@ def test_list_extend_3(): x.extend(z) return x out = list_net_3() - assert np.all(out == (1, 2, 3, 4, ('bb', '2', 3), 'Bob', 'a', ('Michael', 'Bob', '2'), \ - 20, 4, Tensor(1), (1, 2), Tensor(1))) + assert np.all(out == [1, 2, 3, 4, ('bb', '2', 3), 'Bob', 'a', ('Michael', 'Bob', '2'), \ + 20, 4, Tensor(1), (1, 2), Tensor(1)]) def test_list_extend_4(): @@ -90,7 +90,7 @@ def test_list_extend_4(): x.extend(y) return x out = list_net_4() - assert np.all(out == ()) + assert np.all(out == []) def test_list_extend_tuple(): @@ -107,4 +107,4 @@ def test_list_extend_tuple(): return x out = func() - assert np.all(out == (1, 2, 3, 4, 5, 6, 7)) + assert np.all(out == [1, 2, 3, 4, 5, 6, 7]) diff --git a/tests/ut/python/graph_syntax/list/test_list_insert.py b/tests/ut/python/graph_syntax/list/test_list_insert.py index 36c8eea96dd..fae3f409f3f 100644 --- a/tests/ut/python/graph_syntax/list/test_list_insert.py +++ b/tests/ut/python/graph_syntax/list/test_list_insert.py @@ -132,7 +132,7 @@ def test_list_insert_pop_2(): return x, y res_x, res_y = list_insert_pop(-2) - assert np.all(res_x == (3, 1, 4)) + assert np.all(res_x == [3, 1, 4]) assert res_y == 3 diff --git a/tests/ut/python/graph_syntax/list/test_list_mul_number.py b/tests/ut/python/graph_syntax/list/test_list_mul_number.py index daa5c8bc7cb..bd8d3ecf900 100644 --- a/tests/ut/python/graph_syntax/list/test_list_mul_number.py +++ b/tests/ut/python/graph_syntax/list/test_list_mul_number.py @@ -40,9 +40,9 @@ def test_list_mul_number(): """ net = Net() expect_ret0 = [Tensor([1, 2, 3])] * 5 - expect_ret1 = (Tensor([1, 2, 3]),) * 0 - assert isinstance(net()[0], tuple) - assert isinstance(net()[1], tuple) + expect_ret1 = [Tensor([1, 2, 3]),] * 0 + assert isinstance(net()[0], list) + assert isinstance(net()[1], list) for i in range(len(net()[0])): assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy()) assert net()[1] == expect_ret1 diff --git a/tests/ut/python/graph_syntax/list/test_list_pop.py b/tests/ut/python/graph_syntax/list/test_list_pop.py index 4f39e18add5..6052e6da30e 100644 --- a/tests/ut/python/graph_syntax/list/test_list_pop.py +++ b/tests/ut/python/graph_syntax/list/test_list_pop.py @@ -33,7 +33,7 @@ def test_list_pop_1(): return x, y res_x, res_y = list_pop() - assert np.all(res_x == (1, 2, 3)) + assert np.all(res_x == [1, 2, 3]) assert res_y == 4 @@ -50,7 +50,7 @@ def test_list_pop_2(): return x, y res_x, res_y = list_pop() - assert np.all(res_x == (1, 2, 4)) + assert np.all(res_x == [1, 2, 4]) assert res_y == 3 @@ -67,7 +67,7 @@ def test_list_pop_3(): return x, y res_x, res_y = list_pop() - assert np.all(res_x == (1, 3, 4)) + assert np.all(res_x == [1, 3, 4]) assert res_y == 2 @@ -140,8 +140,8 @@ def test_list_pop_7(): return x1, x2, y1 + y2 res_x1, res_x2, res_y = list_pop() - assert np.all(res_x1 == (1, 3, 4)) - assert np.all(res_x2 == (5, 6, 8)) + assert np.all(res_x1 == [1, 3, 4]) + assert np.all(res_x2 == [5, 6, 8]) assert res_y == 9 @@ -158,7 +158,7 @@ def test_list_pop_8(): return x, y res_x, res_y = list_pop(2) - assert res_x == (Tensor([1]), Tensor([2])) + assert res_x == [Tensor([1]), Tensor([2])] assert res_y == Tensor([3]) @@ -175,7 +175,7 @@ def test_list_pop_9(): input_x = [Tensor([1]), Tensor([2]), Tensor([3])] res_x, res_y = list_pop(input_x, 2) - assert res_x == (Tensor([1]), Tensor([2])) + assert res_x == [Tensor([1]), Tensor([2])] assert res_y == Tensor([3]) diff --git a/tests/ut/python/graph_syntax/list/test_list_reverse.py b/tests/ut/python/graph_syntax/list/test_list_reverse.py index 7af782840e5..17dc5f2b8bb 100644 --- a/tests/ut/python/graph_syntax/list/test_list_reverse.py +++ b/tests/ut/python/graph_syntax/list/test_list_reverse.py @@ -32,7 +32,7 @@ def test_list_reverse_1(): x.reverse() return x out = list_net_1() - assert np.all(out == (4, 3, 2, 1)) + assert np.all(out == [4, 3, 2, 1]) def test_list_reverse_2(): @@ -48,7 +48,7 @@ def test_list_reverse_2(): x.reverse() return x out = list_net_2() - assert np.all(out == (4, 20, ('bb', '2', 3), 'a')) + assert np.all(out == [4, 20, ('bb', '2', 3), 'a']) def test_list_reverse_3(): @@ -65,7 +65,7 @@ def test_list_reverse_3(): x.reverse() return x out = list_net_3() - assert np.all(out == (Tensor(1), (1, 2), Tensor(1), 4, 20, ('Michael', 'Bob', '2'), 'a')) + assert np.all(out == [Tensor(1), (1, 2), Tensor(1), 4, 20, ('Michael', 'Bob', '2'), 'a']) def test_list_reverse_4(): @@ -80,4 +80,4 @@ def test_list_reverse_4(): x.reverse() return x out = list_net_4() - assert np.all(out == ()) + assert np.all(out == []) diff --git a/tests/ut/python/graph_syntax/list/test_number_mul_list.py b/tests/ut/python/graph_syntax/list/test_number_mul_list.py index bf413479e5e..c40822f747f 100644 --- a/tests/ut/python/graph_syntax/list/test_number_mul_list.py +++ b/tests/ut/python/graph_syntax/list/test_number_mul_list.py @@ -39,9 +39,9 @@ def test_number_mul_list(): context.set_context(mode=context.GRAPH_MODE) net = Net() expect_ret0 = 5 * [Tensor([1, 2, 3])] - expect_ret1 = 0 * (Tensor([1, 2, 3]),) - assert isinstance(net()[0], tuple) - assert isinstance(net()[1], tuple) + expect_ret1 = 0 * [Tensor([1, 2, 3]),] + assert isinstance(net()[0], list) + assert isinstance(net()[1], list) for i in range(len(net()[0])): assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy()) assert net()[1] == expect_ret1 diff --git a/tests/ut/python/graph_syntax/python_builtin_functions/test_dict.py b/tests/ut/python/graph_syntax/python_builtin_functions/test_dict.py index a0416c1778e..df5cc05ab0d 100644 --- a/tests/ut/python/graph_syntax/python_builtin_functions/test_dict.py +++ b/tests/ut/python/graph_syntax/python_builtin_functions/test_dict.py @@ -30,7 +30,7 @@ def test_fallback_dict_empty(): dict_x['a'] = [1, 2, 3] return dict_x["a"] - assert foo() == (1, 2, 3) + assert foo() == [1, 2, 3] def test_fallback_dict_zip_iter_assign(): diff --git a/tests/ut/python/graph_syntax/python_builtin_functions/test_graph_fallback_python_builtin.py b/tests/ut/python/graph_syntax/python_builtin_functions/test_graph_fallback_python_builtin.py index 7d11dcffc02..584ae5b4164 100644 --- a/tests/ut/python/graph_syntax/python_builtin_functions/test_graph_fallback_python_builtin.py +++ b/tests/ut/python/graph_syntax/python_builtin_functions/test_graph_fallback_python_builtin.py @@ -137,7 +137,7 @@ def test_fallback_reversed(): def foo(): x = reversed([1, 2, 3]) return list(x) - assert foo() == (3, 2, 1) + assert foo() == [3, 2, 1] def test_fallback_set(): diff --git a/tests/ut/python/graph_syntax/python_builtin_functions/test_list_tuple.py b/tests/ut/python/graph_syntax/python_builtin_functions/test_list_tuple.py index 552559f0cf9..3f60dfae147 100644 --- a/tests/ut/python/graph_syntax/python_builtin_functions/test_list_tuple.py +++ b/tests/ut/python/graph_syntax/python_builtin_functions/test_list_tuple.py @@ -35,8 +35,8 @@ def test_fallback_list_with_input_tuple(): x.append(4) return x out = foo() - assert isinstance(out, tuple) - assert operator.eq(out, (1, 2, 3, 4)) + assert isinstance(out, list) + assert operator.eq(out, [1, 2, 3, 4]) def test_fallback_list_with_input_list(): @@ -51,8 +51,8 @@ def test_fallback_list_with_input_list(): x.append(4) return x out = foo() - assert isinstance(out, tuple) - assert operator.eq(out, (1, 2, 3, 4)) + assert isinstance(out, list) + assert operator.eq(out, [1, 2, 3, 4]) def test_fallback_list_with_input_dict(): @@ -67,8 +67,8 @@ def test_fallback_list_with_input_dict(): x.append('d') return x out = foo() - assert isinstance(out, tuple) - assert operator.eq(out, ('a', 'b', 'c', 'd')) + assert isinstance(out, list) + assert operator.eq(out, ['a', 'b', 'c', 'd']) def test_fallback_list_with_input_numpy_array(): diff --git a/tests/ut/python/graph_syntax/python_builtin_functions/test_max_min.py b/tests/ut/python/graph_syntax/python_builtin_functions/test_max_min.py index df9119530b0..b1d2048e548 100644 --- a/tests/ut/python/graph_syntax/python_builtin_functions/test_max_min.py +++ b/tests/ut/python/graph_syntax/python_builtin_functions/test_max_min.py @@ -131,7 +131,7 @@ def test_fallback_max_with_two_inputs_list(): x = max([1, 2, 3], [4, 5]) return x out = foo() - assert operator.eq(out, (4, 5)) + assert operator.eq(out, [4, 5]) def test_fallback_min_with_two_inputs_list(): @@ -145,7 +145,7 @@ def test_fallback_min_with_two_inputs_list(): x = min([1, 2, 3], [4, 5]) return x out = foo() - assert operator.eq(out, (1, 2, 3)) + assert operator.eq(out, [1, 2, 3]) def test_builtin_function_max_min_with_string(): diff --git a/tests/ut/python/graph_syntax/statements/test_list_comp.py b/tests/ut/python/graph_syntax/statements/test_list_comp.py index cd76bec2b30..ecdbda5d88e 100644 --- a/tests/ut/python/graph_syntax/statements/test_list_comp.py +++ b/tests/ut/python/graph_syntax/statements/test_list_comp.py @@ -53,21 +53,21 @@ def get_list_comp_5(): @jit def get_generator_exp_1(): t = (x for x in range(1, 6)) - return t + return tuple(t) @jit def get_generator_exp_2(): t = (x * x for x in range(1, 11) if x > 5 if x % 2 == 0) - return t + return tuple(t) def test_list_comp(): context.set_context(mode=context.GRAPH_MODE) - assert get_list_comp_1() == (1, 2, 3, 4, 5) - assert get_list_comp_2() == (1, 4, 9, 16, 25) - assert get_list_comp_3() == (4, 16, 36, 64, 100) - assert get_list_comp_4() == (36, 64, 100) + assert get_list_comp_1() == [1, 2, 3, 4, 5] + assert get_list_comp_2() == [1, 4, 9, 16, 25] + assert get_list_comp_3() == [4, 16, 36, 64, 100] + assert get_list_comp_4() == [36, 64, 100] with pytest.raises(TypeError) as ex: get_list_comp_5() assert "The 'generators' supports 1 'comprehension' in ListComp/GeneratorExp" in str(ex.value) diff --git a/tests/ut/python/ops/test_filter.py b/tests/ut/python/ops/test_filter.py index 84b08ca76e6..cf8b2bd24de 100644 --- a/tests/ut/python/ops/test_filter.py +++ b/tests/ut/python/ops/test_filter.py @@ -41,4 +41,4 @@ class NetWork(Cell): list1 = [1, 2, 3] net1 = NetWork() result = net1(list1) -assert result == (1, 3) +assert result == [1, 3] diff --git a/tests/ut/python/pipeline/infer/test_hypermap.py b/tests/ut/python/pipeline/infer/test_hypermap.py index da3721060a2..ebde5c8d8ec 100644 --- a/tests/ut/python/pipeline/infer/test_hypermap.py +++ b/tests/ut/python/pipeline/infer/test_hypermap.py @@ -58,7 +58,7 @@ def test_hypermap_value(): return map(lambda a: a + 1, self._list) net = Net() - assert net() == (23, 67, 89, 112) + assert net() == [23, 67, 89, 112] def test_hypermap_func_const(): @@ -80,4 +80,4 @@ def test_hypermap_func_const(): return map(lambda f: f(4), _list) net = NetMap() - assert net() == (8, 12, 16) + assert net() == [8, 12, 16] diff --git a/tests/ut/python/pipeline/parse/test_structure_output.py b/tests/ut/python/pipeline/parse/test_structure_output.py index 8d20c944da7..98053554f25 100644 --- a/tests/ut/python/pipeline/parse/test_structure_output.py +++ b/tests/ut/python/pipeline/parse/test_structure_output.py @@ -66,7 +66,7 @@ def test_output_const_list(): return ret net = Net() - assert net() == (1, 2, 3) + assert net() == [1, 2, 3] def test_output_const_int():