forked from mindspore-Ecosystem/mindspore
Enable list as output
This commit is contained in:
parent
6212278df0
commit
c30a1c7018
|
@ -40,6 +40,21 @@ bool ConvertListToTuple::Run(const FuncGraphPtr &graph) {
|
|||
<< ",debug name:" << node->DebugString();
|
||||
}
|
||||
}
|
||||
auto manager = graph->manager();
|
||||
MS_EXCEPTION_IF_NULL(manager);
|
||||
for (auto node : graph->parameters()) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
// Convert unused list parameter to tuple.
|
||||
if (manager->node_users().find(node) != manager->node_users().end()) {
|
||||
continue;
|
||||
}
|
||||
auto new_abs = ConvertSequenceAbsToTupleAbs(node->abstract());
|
||||
if (new_abs != nullptr) {
|
||||
node->set_abstract(new_abs);
|
||||
MS_LOG(INFO) << "Convert sequence abstract to tuple abstract for op:" << node->fullname_with_scope()
|
||||
<< ",debug name:" << node->DebugString();
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -613,63 +613,6 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
~CleanAfterOptARewriter() override = default;
|
||||
|
||||
protected:
|
||||
// From:
|
||||
// MakeList(arg1, arg2, ...)
|
||||
// To:
|
||||
// MakeTuple(arg1, arg2, ...)
|
||||
AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.reserve(node->size());
|
||||
(void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
// Inputs of node should be [make_list, item1, item2, ...], so offset by 1 to get items;
|
||||
(void)inputs.insert(inputs.cend(), node->inputs().cbegin() + 1, node->inputs().cend());
|
||||
return node->func_graph()->NewCNode(std::move(inputs));
|
||||
}
|
||||
|
||||
// From:
|
||||
// list_getitem(list, key)
|
||||
// To:
|
||||
// TupleGetItem(list, key)
|
||||
AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
|
||||
// Inputs should be [list_getitem, list, item]
|
||||
constexpr size_t expect_input_size = 3;
|
||||
CheckInputsSize(node, expect_input_size);
|
||||
constexpr size_t data_index = 1;
|
||||
constexpr size_t cons_index = 2;
|
||||
const auto &inputs = node->inputs();
|
||||
auto &data = inputs[data_index];
|
||||
auto &key = inputs[cons_index];
|
||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, key});
|
||||
}
|
||||
|
||||
// From:
|
||||
// ListSetItem(list, index, item)
|
||||
// To:
|
||||
// TupleSetItem(list, index, item)
|
||||
AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
MS_EXCEPTION_IF_NULL(node->func_graph());
|
||||
|
||||
// Inputs should be [list_setitem, list, index, item]
|
||||
const size_t expect_inputs_size = 4;
|
||||
CheckInputsSize(node, expect_inputs_size);
|
||||
|
||||
const size_t data_index = 1;
|
||||
const size_t cons_index = 2;
|
||||
const size_t value_index = 3;
|
||||
const auto &inputs = node->inputs();
|
||||
auto &data = inputs[data_index];
|
||||
auto &key = inputs[cons_index];
|
||||
auto &value = inputs[value_index];
|
||||
return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, key, value});
|
||||
}
|
||||
|
||||
// From:
|
||||
// MakeSparseTensor(indices, values, dense_shape)
|
||||
// To:
|
||||
|
@ -922,9 +865,6 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
using Converter = AnfNodePtr (ThisClass::*)(const CNodePtr &);
|
||||
using ConverterMap = mindspore::HashMap<PrimitivePtr, Converter, PrimitiveHasher, PrimitiveEqual>;
|
||||
static inline const ConverterMap converters_{
|
||||
{prim::kPrimMakeList, &ThisClass::ConvertMakeListToMakeTuple},
|
||||
{prim::kPrimListGetItem, &ThisClass::ConvertListGetItemToTupleGetItem},
|
||||
{prim::kPrimListSetItem, &ThisClass::ConvertListSetItemToTupleSetItem},
|
||||
// SparseProcess: 1.MakeSparse->MakeTuple 2.SparseGetAttr->TupleGetItem
|
||||
{prim::kPrimMakeRowTensor, &ThisClass::ConvertMakeSparseToMakeTuple},
|
||||
{prim::kPrimRowTensorGetIndices, &ThisClass::ConvertSparseGetAttrToTupleGetItem},
|
||||
|
@ -994,39 +934,18 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
return (this->*(iter->second))(cnode);
|
||||
}
|
||||
|
||||
static ValuePtr ConvertValueSequenceToValueTuple(const ValuePtr &value, size_t depth, bool *need_convert) {
|
||||
MS_EXCEPTION_IF_NULL(need_convert);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (depth > kMaxSeqRecursiveDepth) {
|
||||
MS_LOG(EXCEPTION) << "List nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
|
||||
}
|
||||
|
||||
if (value->isa<ValueSequence>()) {
|
||||
std::vector<ValuePtr> elements;
|
||||
auto value_seq = value->cast<ValueSequencePtr>();
|
||||
(void)std::transform(value_seq->value().begin(), value_seq->value().end(), std::back_inserter(elements),
|
||||
[&](const ValuePtr &value) -> ValuePtr {
|
||||
bool is_convert = false;
|
||||
auto convert_value = ConvertValueSequenceToValueTuple(value, depth + 1, &is_convert);
|
||||
*need_convert |= is_convert;
|
||||
return convert_value;
|
||||
});
|
||||
*need_convert |= value->isa<ValueList>();
|
||||
if (*need_convert) {
|
||||
return std::make_shared<ValueTuple>(elements);
|
||||
}
|
||||
}
|
||||
|
||||
return value;
|
||||
}
|
||||
|
||||
AnfNodePtr ProcessValueSequence(const ValuePtr &value) {
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
if (value->isa<ValueSequence>()) {
|
||||
auto value_seq = value->cast<ValueSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(value_seq);
|
||||
auto values = value_seq->value();
|
||||
std::vector<AnfNodePtr> value_seq_inputs{NewValueNode(prim::kPrimMakeTuple)};
|
||||
std::vector<AnfNodePtr> value_seq_inputs;
|
||||
if (value_seq->isa<ValueList>()) {
|
||||
(void)value_seq_inputs.emplace_back(NewValueNode(prim::kPrimMakeList));
|
||||
} else {
|
||||
(void)value_seq_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
}
|
||||
for (auto inner_value : values) {
|
||||
auto inner_value_seq = ProcessValueSequence(inner_value);
|
||||
(void)value_seq_inputs.emplace_back(inner_value_seq);
|
||||
|
@ -1091,60 +1010,14 @@ class CleanAfterOptARewriter : public BaseRewriter {
|
|||
return ConvertInterpretedObjectValue(value_node, value->cast<parse::InterpretedObjectPtr>());
|
||||
}
|
||||
}
|
||||
bool need_convert = false;
|
||||
auto convert_value = ConvertValueSequenceToValueTuple(value, 0, &need_convert);
|
||||
if (need_convert) {
|
||||
return std::make_shared<ValueNode>(convert_value);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// AbstractSequence, AbstractRowTensor --> AbstractTuple.
|
||||
// AbstractRowTensor --> AbstractTuple.
|
||||
static AbstractBasePtr ConvertToAbstractTuple(const AbstractBasePtr &abs, size_t depth) {
|
||||
if (depth > kMaxSeqRecursiveDepth) {
|
||||
MS_LOG(EXCEPTION) << "List or Dict nesting is not allowed more than " << kMaxSeqRecursiveDepth << " levels.";
|
||||
}
|
||||
// AbstractList --> AbstractTuple.
|
||||
auto abs_seq = abs->cast<AbstractSequencePtr>();
|
||||
if (abs_seq != nullptr) {
|
||||
// Dynamic length sequence do not convert.
|
||||
if (abs_seq->dynamic_len() && abs_seq->isa<AbstractList>()) {
|
||||
auto converted_abs_tuple = std::make_shared<AbstractTuple>(abs_seq->elements(), abs_seq->sequence_nodes());
|
||||
converted_abs_tuple->set_dynamic_len(true);
|
||||
converted_abs_tuple->set_dynamic_len_element_abs(abs_seq->dynamic_len_element_abs());
|
||||
return converted_abs_tuple;
|
||||
}
|
||||
const auto &seq_elements = abs_seq->elements();
|
||||
// First we check if elements should be converted,
|
||||
// changed_elements maps old element to new element.
|
||||
mindspore::HashMap<AbstractBasePtr, AbstractBasePtr> changed_elements;
|
||||
for (const auto &element : seq_elements) {
|
||||
auto new_element = ConvertToAbstractTuple(element, depth + 1);
|
||||
if (new_element != nullptr) {
|
||||
(void)changed_elements.emplace(element, new_element);
|
||||
}
|
||||
}
|
||||
if (changed_elements.empty()) {
|
||||
if (abs->isa<AbstractTuple>()) {
|
||||
// If no elements changed and it is an AbstractTuple, do not convert.
|
||||
return nullptr;
|
||||
}
|
||||
// If no elements changed but it is not an AbstractTuple, convert it by copy elements.
|
||||
return std::make_shared<AbstractTuple>(seq_elements);
|
||||
}
|
||||
// Always make new AbstractTuple when elements changed.
|
||||
std::vector<AbstractBasePtr> elements;
|
||||
elements.reserve(seq_elements.size());
|
||||
for (const auto &element : seq_elements) {
|
||||
auto iter = changed_elements.find(element);
|
||||
if (iter != changed_elements.end()) {
|
||||
(void)elements.emplace_back(iter->second);
|
||||
} else {
|
||||
(void)elements.emplace_back(element);
|
||||
}
|
||||
}
|
||||
return std::make_shared<AbstractTuple>(std::move(elements));
|
||||
}
|
||||
// AbstractRowTensor --> AbstractTuple.
|
||||
auto abs_row_tensor = abs->cast<std::shared_ptr<AbstractRowTensor>>();
|
||||
if (abs_row_tensor != nullptr) {
|
||||
|
|
|
@ -502,6 +502,49 @@ py::object VectorToPyData(const Any &value) {
|
|||
}
|
||||
return ret;
|
||||
}
|
||||
template <typename T>
|
||||
py::object AbstractSequenceToPyData(const VectorRef &value_list, const AbstractBasePtr &abs) {
|
||||
auto value_size = value_list.size();
|
||||
auto ret = T(value_size);
|
||||
auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(seq_abs);
|
||||
bool dynamic_len = seq_abs->dynamic_len();
|
||||
auto dynamic_len_element_abs = seq_abs->dynamic_len_element_abs();
|
||||
if (dynamic_len || dynamic_len_element_abs != nullptr) {
|
||||
if (dynamic_len_element_abs == nullptr) {
|
||||
MS_LOG(INFO) << "Dynamic length sequence with no specified element abstract convert to empty tuple.";
|
||||
for (size_t i = 0; i < value_size; i++) {
|
||||
ret[i] = BaseRefToPyData(value_list[i]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
if (dynamic_len_element_abs->isa<abstract::AbstractNone>()) {
|
||||
MS_LOG(INFO) << "Dynamic length sequence with element None convert to empty sequence.";
|
||||
return ret;
|
||||
}
|
||||
for (size_t i = 0; i < value_size; ++i) {
|
||||
ret[i] = BaseRefToPyData(value_list[i], dynamic_len_element_abs);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
|
||||
// If FALLBACK_RUNTIME is not enable
|
||||
// The size of seq_abs may be larger than the size of value_list, because the backend will eliminate None.
|
||||
size_t ref_idx = 0;
|
||||
for (size_t i = 0; i < seq_abs->size(); i++) {
|
||||
auto elem_abs = seq_abs->elements()[i];
|
||||
if (elem_abs->isa<abstract::AbstractNone>() && !support_fallback_runtime) {
|
||||
continue;
|
||||
}
|
||||
ret[ref_idx] = BaseRefToPyData(value_list[ref_idx], elem_abs);
|
||||
ref_idx++;
|
||||
}
|
||||
if (ref_idx != value_size) {
|
||||
MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got "
|
||||
<< ref_idx;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr &abs) {
|
||||
py::object ret;
|
||||
|
@ -516,55 +559,16 @@ py::object VectorRefToPyData(const VectorRef &value_list, const AbstractBasePtr
|
|||
}
|
||||
|
||||
// Current VectorRef reflects a COOTensor type
|
||||
MS_LOG(DEBUG) << "abs: " << abs->ToString();
|
||||
if (abs->isa<abstract::AbstractCSRTensor>()) {
|
||||
return MakeCSRTensor(value_list);
|
||||
}
|
||||
if (abs->isa<abstract::AbstractCOOTensor>()) {
|
||||
return MakeCOOTensor(value_list);
|
||||
}
|
||||
if (!abs->isa<abstract::AbstractSequence>()) {
|
||||
return VectorRefToPyData(value_list, nullptr);
|
||||
if (abs->isa<abstract::AbstractList>()) {
|
||||
return AbstractSequenceToPyData<py::list>(value_list, abs);
|
||||
}
|
||||
auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
|
||||
MS_EXCEPTION_IF_NULL(seq_abs);
|
||||
bool dynamic_len = seq_abs->dynamic_len();
|
||||
auto dynamic_len_element_abs = seq_abs->dynamic_len_element_abs();
|
||||
if (dynamic_len || dynamic_len_element_abs != nullptr) {
|
||||
if (dynamic_len_element_abs == nullptr) {
|
||||
MS_LOG(INFO) << "Dynamic length sequence with no specified element abstract convert to empty tuple.";
|
||||
for (size_t i = 0; i < value_size; i++) {
|
||||
ref_tuple[i] = BaseRefToPyData(value_list[i]);
|
||||
}
|
||||
return ref_tuple;
|
||||
}
|
||||
if (dynamic_len_element_abs->isa<abstract::AbstractNone>()) {
|
||||
MS_LOG(INFO) << "Dynamic length sequence with element None convert to empty tuple.";
|
||||
return ref_tuple;
|
||||
}
|
||||
for (size_t i = 0; i < value_size; ++i) {
|
||||
ref_tuple[i] = BaseRefToPyData(value_list[i], dynamic_len_element_abs);
|
||||
}
|
||||
return ref_tuple;
|
||||
}
|
||||
static const auto support_fallback_runtime = (common::GetEnv("MS_DEV_ENABLE_FALLBACK_RUNTIME") != "0");
|
||||
// If FALLBACK_RUNTIME is not enable
|
||||
// The size of seq_abs may be larger than the size of value_list, because the backend will eliminate None.
|
||||
size_t ref_idx = 0;
|
||||
for (size_t i = 0; i < seq_abs->size(); i++) {
|
||||
auto elem_abs = seq_abs->elements()[i];
|
||||
if (elem_abs->isa<abstract::AbstractNone>() && !support_fallback_runtime) {
|
||||
continue;
|
||||
}
|
||||
ref_tuple[ref_idx] = BaseRefToPyData(value_list[ref_idx], elem_abs);
|
||||
ref_idx++;
|
||||
}
|
||||
if (ref_idx != value_size) {
|
||||
MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got "
|
||||
<< ref_idx;
|
||||
}
|
||||
ret = ref_tuple;
|
||||
return ret;
|
||||
return AbstractSequenceToPyData<py::tuple>(value_list, abs);
|
||||
}
|
||||
|
||||
bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
|
||||
|
|
|
@ -0,0 +1,558 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Test return list type object from graph"""
|
||||
import os
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import ops
|
||||
from mindspore.common import mutable
|
||||
from mindspore import Tensor, jit, context
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_constant_list():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return constant list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return [1, 2, 3, 4]
|
||||
|
||||
res = foo()
|
||||
assert res == [1, 2, 3, 4]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_constant_list_2():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return constant list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return ["a", "b", "c", "d"]
|
||||
|
||||
res = foo()
|
||||
assert res == ["a", "b", "c", "d"]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_constant_list_3():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return constant list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return [True, False, False, True]
|
||||
|
||||
res = foo()
|
||||
assert res == [True, False, False, True]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_constant_list_4():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return constant list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return [Tensor([1]), Tensor([1, 2, 3]), Tensor([2, 3])]
|
||||
|
||||
res = foo()
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0].asnumpy() == np.array([1]))
|
||||
assert np.all(res[1].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(res[2].asnumpy() == np.array([2, 3]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_constant_list_5():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return constant list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return [None, None, None]
|
||||
|
||||
res = foo()
|
||||
assert res == [None, None, None]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_constant_list_6():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return constant list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return [np.array([1, 2, 3]), np.array([4, 5, 6]), 1]
|
||||
|
||||
res = foo()
|
||||
assert isinstance(res, list)
|
||||
assert len(res) == 3
|
||||
assert np.all(res[0] == np.array([1, 2, 3]))
|
||||
assert np.all(res[1] == np.array([4, 5, 6]))
|
||||
assert res[2] == 1
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_constant_list_7():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return constant list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return [1, "a", True, None, Tensor([2])]
|
||||
|
||||
res = foo()
|
||||
assert res == [1, "a", True, None, Tensor([2])]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_make_list_node():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return make list node.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
os.environ["GRAPH_OP_RUN"] = "1"
|
||||
@jit
|
||||
def foo(x):
|
||||
return [x, x+1, x+2, 1]
|
||||
|
||||
res = foo(mutable(1))
|
||||
assert res == [1, 2, 3, 1]
|
||||
os.environ["GRAPH_OP_RUN"] = "0"
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_make_list_node_2():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return make list node.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
return [x, x+1, x+2, Tensor([4])]
|
||||
|
||||
res = foo(Tensor([1]))
|
||||
assert res == [Tensor([1]), Tensor([2]), Tensor([3]), Tensor([4])]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_make_list_node_3():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return make list node.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
return [x, mutable(1), "a"]
|
||||
|
||||
res = foo(Tensor([1]))
|
||||
assert res == [Tensor([1]), 1, "a"]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_list_with_nest():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return make list in nest scene.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return [[1, 2, 3], [4, 5, 6]]
|
||||
|
||||
res = foo()
|
||||
assert res == [[1, 2, 3], [4, 5, 6]]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_list_with_nest_2():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return make list in nest scene.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return [([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6]]
|
||||
|
||||
res = foo()
|
||||
assert res == [([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6]]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_list_with_nest_3():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return make list in nest scene.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return (([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6])
|
||||
|
||||
res = foo()
|
||||
assert res == (([1, 1], [2, 2], (3, [4, 4])), [4, 5, 6])
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_make_list_with_nest():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return make list in nest scene.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
return [[x, x], (x+1, x+2)]
|
||||
|
||||
res = foo(Tensor([0]))
|
||||
assert res == [[Tensor([0]), Tensor([0])], (Tensor([1]), Tensor([2]))]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_make_list_with_nest_2():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return make list in nest scene.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
return [x, ([x, 1],)], (x+1, x+2)
|
||||
|
||||
res = foo(Tensor([0]))
|
||||
assert res == ([Tensor([0]), ([Tensor([0]), 1],)], (Tensor([1]), Tensor([2])))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_buildin_list_func():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return result of list() function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
return list((1, "2", None, Tensor([1])))
|
||||
|
||||
res = foo()
|
||||
assert res == [1, "2", None, Tensor([1])]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_buildin_list_func_2():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return result of list() function.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
return list(x)
|
||||
|
||||
res = foo(Tensor([1, 2, 3]))
|
||||
assert res == [Tensor([1]), Tensor([2]), Tensor([3])]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_dynamic_length_list():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return dynamic length list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
x = mutable([1, 2, 3], True)
|
||||
return x
|
||||
|
||||
res = foo()
|
||||
assert res == [1, 2, 3]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_dynamic_length_list_2():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return dynamic length list.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(m):
|
||||
x = mutable([m, m+1], True)
|
||||
return x
|
||||
|
||||
res = foo(Tensor([0]))
|
||||
assert res == [Tensor([0]), Tensor([1])]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_list_from_third_party():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return list from third party.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
m = np.array([1, 2, 3, 4])
|
||||
x = m.tolist()
|
||||
return x
|
||||
|
||||
res = foo()
|
||||
assert res == [1, 2, 3, 4]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Getattr for interpret node failed.")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_list_from_third_party_2():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return list from third party.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(m):
|
||||
x = m.asnumpy().tolist()
|
||||
return x
|
||||
|
||||
res = foo(Tensor([1, 2, 3, 4]))
|
||||
assert res == [1, 2, 3, 4]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_list_from_third_party_3():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return list from third party.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo():
|
||||
x = np.arange(0, 10, 2)
|
||||
return list(x)
|
||||
|
||||
res = foo()
|
||||
assert res == [0, 2, 4, 6, 8]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_list_from_dict_attribute():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return list from dict keys and values.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x, y):
|
||||
m = {"1": x, "2": y}
|
||||
return list(m.keys()), list(m.values())
|
||||
|
||||
res = foo(Tensor([1]), mutable(2))
|
||||
assert len(res) == 2
|
||||
assert res[0] == ["1", "2"]
|
||||
assert res[1] == [Tensor([1]), 2]
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Return dict change the abstract.")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_return_list_from_dict_attribute_2():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support return list from dict keys and values.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x, y):
|
||||
m = {"1": x, "2": y}
|
||||
return m, list(m.keys()), list(m.values())
|
||||
|
||||
res = foo(Tensor([1]), mutable(2))
|
||||
assert len(res) == 3
|
||||
assert res[1] == ["1", "2"]
|
||||
assert res[2] == [Tensor([1]), 2]
|
||||
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_for_return_list_graph():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support calculate gradient for graph with list return.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(x):
|
||||
y = ops.ReLU()(x)
|
||||
return [y,]
|
||||
|
||||
x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
|
||||
res = ops.grad(foo)(x)
|
||||
assert np.allclose(res.asnumpy(), np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]).astype(np.float32))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_grad_for_graph_with_list_input():
|
||||
"""
|
||||
Feature: Return list in graph
|
||||
Description: Support calculate gradient for graph with list return.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@jit
|
||||
def foo(t):
|
||||
x = t[0]
|
||||
y = t[1]
|
||||
out = ops.MatMul()(x, y)
|
||||
return out
|
||||
|
||||
t = mutable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
|
||||
output = ops.grad(foo)(t)
|
||||
assert isinstance(output, list)
|
||||
expect = [np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
|
||||
np.array([[1.7, 1.7, 1.7],
|
||||
[1.9, 1.9, 1.9],
|
||||
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||
assert np.allclose(output[0].asnumpy(), expect[0])
|
||||
assert np.allclose(output[1].asnumpy(), expect[1])
|
|
@ -37,7 +37,7 @@ def test_fallback_list_with_input_constant_tensor():
|
|||
x.append(Tensor([4]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert isinstance(out, list)
|
||||
assert len(out) == 4
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert out[0].asnumpy() == 1
|
||||
|
@ -66,7 +66,7 @@ def test_fallback_list_with_input_constant_tensor_2():
|
|||
x.append(Tensor([5, 6]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert isinstance(out, list)
|
||||
assert len(out) == 3
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
|
||||
|
@ -139,7 +139,7 @@ def test_fallback_tuple_with_input_constant_tensor_2():
|
|||
x = list(Tensor([[1, 2], [3, 4]]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert isinstance(out, list)
|
||||
assert len(out) == 2
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
|
||||
|
|
|
@ -269,6 +269,7 @@ def test_raise_with_variable_control_flow3():
|
|||
raise_info_joinedstr_tensor.value)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not support yet")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
|
|
@ -119,7 +119,7 @@ def test_grad_mutable_list_tensor():
|
|||
t = mutable([Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
|
||||
output = GradNetWrtX(Net())(t)
|
||||
assert isinstance(output, tuple)
|
||||
assert isinstance(output, list)
|
||||
expect = [np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
|
||||
np.array([[1.7, 1.7, 1.7],
|
||||
|
@ -310,7 +310,7 @@ def test_grad_mutable_list_tuple_tensor():
|
|||
Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)),
|
||||
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)])
|
||||
output = GradNetWrtX(Net())(t)
|
||||
assert isinstance(output, tuple)
|
||||
assert isinstance(output, list)
|
||||
expect = [[np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32), np.array([[0, 0, 0],
|
||||
[0, 0, 0]]).astype(np.float32)],
|
||||
|
@ -423,6 +423,7 @@ def test_grad_mutable_dict_tuple_tensor():
|
|||
assert compare(output, expect)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Do not support yet")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -467,7 +468,7 @@ def test_grad_mutable_list_dict_tensor():
|
|||
np.array([[1.7, 1.7, 1.7],
|
||||
[1.9, 1.9, 1.9],
|
||||
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||
assert isinstance(output, tuple)
|
||||
assert isinstance(output, list)
|
||||
assert len(output) == 2
|
||||
assert isinstance(output[0], dict)
|
||||
assert len(output[0].keys()) == 2
|
||||
|
@ -579,7 +580,7 @@ def test_grad_mutable_list_tensor_jit_function():
|
|||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
output = GradOperation()(net)(z)
|
||||
assert isinstance(output, tuple)
|
||||
assert isinstance(output, list)
|
||||
expect = [np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
|
||||
np.array([[1.7, 1.7, 1.7],
|
||||
|
@ -673,7 +674,7 @@ def test_grad_mutable_unused_list_tensor():
|
|||
Tensor([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], dtype=mstype.float32),
|
||||
Tensor([[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], dtype=mstype.float32)])
|
||||
output = GradNetWrtX(Net())(t)
|
||||
assert isinstance(output, tuple)
|
||||
assert isinstance(output, list)
|
||||
expect = [np.array([[3., 3., 3.],
|
||||
[3., 3., 3.]]).astype(np.float32),
|
||||
np.array([[0., 0., 0.],
|
||||
|
|
|
@ -279,7 +279,7 @@ def test_grad_const_list_tensor_to_mutable():
|
|||
|
||||
grad_net = GradNetWrtX(Net())
|
||||
output = grad_net()
|
||||
assert isinstance(output, tuple)
|
||||
assert isinstance(output, list)
|
||||
expect = [np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
|
||||
np.array([[1.7, 1.7, 1.7],
|
||||
|
@ -288,7 +288,7 @@ def test_grad_const_list_tensor_to_mutable():
|
|||
assert compare(output, expect)
|
||||
grad_net = GradNetWrtX1(Net())
|
||||
output = grad_net()
|
||||
assert isinstance(output, tuple)
|
||||
assert isinstance(output, list)
|
||||
assert compare(output, expect)
|
||||
|
||||
|
||||
|
@ -338,7 +338,7 @@ def test_grad_const_tuple_or_list_tensor_arg_to_mutable():
|
|||
x = [Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32),
|
||||
Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)]
|
||||
output = grad_net(x)
|
||||
assert isinstance(output, tuple)
|
||||
assert isinstance(output, list)
|
||||
expect = [np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
|
||||
np.array([[1.7, 1.7, 1.7],
|
||||
|
@ -347,6 +347,7 @@ def test_grad_const_tuple_or_list_tensor_arg_to_mutable():
|
|||
assert compare(output, expect)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="not support yet.")
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
|
@ -397,7 +398,7 @@ def test_grad_const_list_and_tuple_tensor_to_mutable():
|
|||
|
||||
grad_net = GradNetWrtX(Net())
|
||||
output = grad_net()
|
||||
assert isinstance(output, tuple)
|
||||
assert isinstance(output, list)
|
||||
expect = [(np.array([[1.4100001, 1.5999999, 6.6],
|
||||
[1.4100001, 1.5999999, 6.6]]).astype(np.float32),
|
||||
np.array([[0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
|
||||
|
@ -405,10 +406,10 @@ def test_grad_const_list_and_tuple_tensor_to_mutable():
|
|||
np.array([[1.7, 1.7, 1.7],
|
||||
[1.9, 1.9, 1.9],
|
||||
[1.5, 1.5, 1.5]]).astype(np.float32)]
|
||||
assert compare(output, expect)
|
||||
assert compare(output, list)
|
||||
grad_net = GradNetWrtX1(Net())
|
||||
output = grad_net()
|
||||
assert isinstance(output, tuple)
|
||||
assert isinstance(output, list)
|
||||
assert compare(output, expect)
|
||||
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@ def test_value_node_with_depend():
|
|||
x = [[1, 2, 3, 4], [5, 6, 7, 8]]
|
||||
net = NetValueNodeWithDepend()
|
||||
output = net(x)
|
||||
assert output == (5, 0, 7, 8)
|
||||
assert output == [5, 0, 7, 8]
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
|
|
|
@ -41,7 +41,7 @@ def test_single_if_no_else_type():
|
|||
|
||||
test_net = FalseNet()
|
||||
res = test_net()
|
||||
assert str(res) == "(<class 'numpy.ndarray'>, <class 'object'>)"
|
||||
assert str(res) == "[<class 'numpy.ndarray'>, <class 'object'>]"
|
||||
|
||||
|
||||
def test_single_if_no_else_type_2():
|
||||
|
@ -64,7 +64,7 @@ def test_single_if_no_else_type_2():
|
|||
|
||||
test_net = TrueNet()
|
||||
res = test_net()
|
||||
assert str(res) == "(<class 'int'>, <class 'object'>)"
|
||||
assert str(res) == "[<class 'int'>, <class 'object'>]"
|
||||
|
||||
|
||||
def test_single_if_1():
|
||||
|
|
|
@ -61,7 +61,7 @@ def test_dict1():
|
|||
input_me = Tensor(input_np)
|
||||
net = Net1()
|
||||
out_me = net(input_me)
|
||||
assert out_me == ('x', 'y', 0, 1)
|
||||
assert out_me == ['x', 'y', 0, 1]
|
||||
|
||||
|
||||
def test_dict2():
|
||||
|
@ -76,7 +76,7 @@ def test_dict3():
|
|||
input_me = Tensor(input_np)
|
||||
net = Net3()
|
||||
out_me = net(input_me)
|
||||
assert out_me == ('x', 'y', 0, (0, 1))
|
||||
assert out_me == ['x', 'y', 0, (0, 1)]
|
||||
|
||||
|
||||
def test_dict4():
|
||||
|
|
|
@ -576,7 +576,7 @@ def test_list_double_slice():
|
|||
class NetInner(Cell):
|
||||
def construct(self, a, b, start1, stop1, step1, start2, stop2, step2):
|
||||
a[start1:stop1:step1][start2: stop2: step2] = b
|
||||
return tuple(a)
|
||||
return a
|
||||
|
||||
net = NetInner()
|
||||
a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
|
@ -728,7 +728,7 @@ def test_list_slice_negetive_step():
|
|||
a = [1, 2, 3, 4, 5]
|
||||
b = [11, 22, 33, 44, 55]
|
||||
a[-1:-4:-1] = b[-1:-4:-1]
|
||||
return tuple(a)
|
||||
return a
|
||||
|
||||
x = py_func()
|
||||
y = ms_func()
|
||||
|
@ -771,6 +771,6 @@ def test_list_slice_only_with_step():
|
|||
a = [1, 2, 3, 4]
|
||||
b = [11, 22]
|
||||
a[::2] = b
|
||||
return tuple(a)
|
||||
return a
|
||||
|
||||
assert ms_func() == py_func()
|
||||
|
|
|
@ -33,7 +33,7 @@ def test_list_extend_1():
|
|||
x.extend(y)
|
||||
return x
|
||||
out = list_net_1()
|
||||
assert np.all(out == (1, 2, 3, 4, 5, 6, 7))
|
||||
assert np.all(out == [1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
|
||||
def test_list_extend_2():
|
||||
|
@ -52,7 +52,7 @@ def test_list_extend_2():
|
|||
x.extend(z)
|
||||
return x
|
||||
out = list_net_2()
|
||||
assert np.all(out == (1, 2, 3, 4, ('bb', '2', 3), 20))
|
||||
assert np.all(out == [1, 2, 3, 4, ('bb', '2', 3), 20])
|
||||
|
||||
|
||||
def test_list_extend_3():
|
||||
|
@ -73,8 +73,8 @@ def test_list_extend_3():
|
|||
x.extend(z)
|
||||
return x
|
||||
out = list_net_3()
|
||||
assert np.all(out == (1, 2, 3, 4, ('bb', '2', 3), 'Bob', 'a', ('Michael', 'Bob', '2'), \
|
||||
20, 4, Tensor(1), (1, 2), Tensor(1)))
|
||||
assert np.all(out == [1, 2, 3, 4, ('bb', '2', 3), 'Bob', 'a', ('Michael', 'Bob', '2'), \
|
||||
20, 4, Tensor(1), (1, 2), Tensor(1)])
|
||||
|
||||
|
||||
def test_list_extend_4():
|
||||
|
@ -90,7 +90,7 @@ def test_list_extend_4():
|
|||
x.extend(y)
|
||||
return x
|
||||
out = list_net_4()
|
||||
assert np.all(out == ())
|
||||
assert np.all(out == [])
|
||||
|
||||
|
||||
def test_list_extend_tuple():
|
||||
|
@ -107,4 +107,4 @@ def test_list_extend_tuple():
|
|||
return x
|
||||
|
||||
out = func()
|
||||
assert np.all(out == (1, 2, 3, 4, 5, 6, 7))
|
||||
assert np.all(out == [1, 2, 3, 4, 5, 6, 7])
|
||||
|
|
|
@ -132,7 +132,7 @@ def test_list_insert_pop_2():
|
|||
return x, y
|
||||
|
||||
res_x, res_y = list_insert_pop(-2)
|
||||
assert np.all(res_x == (3, 1, 4))
|
||||
assert np.all(res_x == [3, 1, 4])
|
||||
assert res_y == 3
|
||||
|
||||
|
||||
|
|
|
@ -40,9 +40,9 @@ def test_list_mul_number():
|
|||
"""
|
||||
net = Net()
|
||||
expect_ret0 = [Tensor([1, 2, 3])] * 5
|
||||
expect_ret1 = (Tensor([1, 2, 3]),) * 0
|
||||
assert isinstance(net()[0], tuple)
|
||||
assert isinstance(net()[1], tuple)
|
||||
expect_ret1 = [Tensor([1, 2, 3]),] * 0
|
||||
assert isinstance(net()[0], list)
|
||||
assert isinstance(net()[1], list)
|
||||
for i in range(len(net()[0])):
|
||||
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
|
||||
assert net()[1] == expect_ret1
|
||||
|
|
|
@ -33,7 +33,7 @@ def test_list_pop_1():
|
|||
return x, y
|
||||
|
||||
res_x, res_y = list_pop()
|
||||
assert np.all(res_x == (1, 2, 3))
|
||||
assert np.all(res_x == [1, 2, 3])
|
||||
assert res_y == 4
|
||||
|
||||
|
||||
|
@ -50,7 +50,7 @@ def test_list_pop_2():
|
|||
return x, y
|
||||
|
||||
res_x, res_y = list_pop()
|
||||
assert np.all(res_x == (1, 2, 4))
|
||||
assert np.all(res_x == [1, 2, 4])
|
||||
assert res_y == 3
|
||||
|
||||
|
||||
|
@ -67,7 +67,7 @@ def test_list_pop_3():
|
|||
return x, y
|
||||
|
||||
res_x, res_y = list_pop()
|
||||
assert np.all(res_x == (1, 3, 4))
|
||||
assert np.all(res_x == [1, 3, 4])
|
||||
assert res_y == 2
|
||||
|
||||
|
||||
|
@ -140,8 +140,8 @@ def test_list_pop_7():
|
|||
return x1, x2, y1 + y2
|
||||
|
||||
res_x1, res_x2, res_y = list_pop()
|
||||
assert np.all(res_x1 == (1, 3, 4))
|
||||
assert np.all(res_x2 == (5, 6, 8))
|
||||
assert np.all(res_x1 == [1, 3, 4])
|
||||
assert np.all(res_x2 == [5, 6, 8])
|
||||
assert res_y == 9
|
||||
|
||||
|
||||
|
@ -158,7 +158,7 @@ def test_list_pop_8():
|
|||
return x, y
|
||||
|
||||
res_x, res_y = list_pop(2)
|
||||
assert res_x == (Tensor([1]), Tensor([2]))
|
||||
assert res_x == [Tensor([1]), Tensor([2])]
|
||||
assert res_y == Tensor([3])
|
||||
|
||||
|
||||
|
@ -175,7 +175,7 @@ def test_list_pop_9():
|
|||
|
||||
input_x = [Tensor([1]), Tensor([2]), Tensor([3])]
|
||||
res_x, res_y = list_pop(input_x, 2)
|
||||
assert res_x == (Tensor([1]), Tensor([2]))
|
||||
assert res_x == [Tensor([1]), Tensor([2])]
|
||||
assert res_y == Tensor([3])
|
||||
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ def test_list_reverse_1():
|
|||
x.reverse()
|
||||
return x
|
||||
out = list_net_1()
|
||||
assert np.all(out == (4, 3, 2, 1))
|
||||
assert np.all(out == [4, 3, 2, 1])
|
||||
|
||||
|
||||
def test_list_reverse_2():
|
||||
|
@ -48,7 +48,7 @@ def test_list_reverse_2():
|
|||
x.reverse()
|
||||
return x
|
||||
out = list_net_2()
|
||||
assert np.all(out == (4, 20, ('bb', '2', 3), 'a'))
|
||||
assert np.all(out == [4, 20, ('bb', '2', 3), 'a'])
|
||||
|
||||
|
||||
def test_list_reverse_3():
|
||||
|
@ -65,7 +65,7 @@ def test_list_reverse_3():
|
|||
x.reverse()
|
||||
return x
|
||||
out = list_net_3()
|
||||
assert np.all(out == (Tensor(1), (1, 2), Tensor(1), 4, 20, ('Michael', 'Bob', '2'), 'a'))
|
||||
assert np.all(out == [Tensor(1), (1, 2), Tensor(1), 4, 20, ('Michael', 'Bob', '2'), 'a'])
|
||||
|
||||
|
||||
def test_list_reverse_4():
|
||||
|
@ -80,4 +80,4 @@ def test_list_reverse_4():
|
|||
x.reverse()
|
||||
return x
|
||||
out = list_net_4()
|
||||
assert np.all(out == ())
|
||||
assert np.all(out == [])
|
||||
|
|
|
@ -39,9 +39,9 @@ def test_number_mul_list():
|
|||
context.set_context(mode=context.GRAPH_MODE)
|
||||
net = Net()
|
||||
expect_ret0 = 5 * [Tensor([1, 2, 3])]
|
||||
expect_ret1 = 0 * (Tensor([1, 2, 3]),)
|
||||
assert isinstance(net()[0], tuple)
|
||||
assert isinstance(net()[1], tuple)
|
||||
expect_ret1 = 0 * [Tensor([1, 2, 3]),]
|
||||
assert isinstance(net()[0], list)
|
||||
assert isinstance(net()[1], list)
|
||||
for i in range(len(net()[0])):
|
||||
assert np.array_equal(net()[0][i].asnumpy(), expect_ret0[i].asnumpy())
|
||||
assert net()[1] == expect_ret1
|
||||
|
|
|
@ -30,7 +30,7 @@ def test_fallback_dict_empty():
|
|||
dict_x['a'] = [1, 2, 3]
|
||||
return dict_x["a"]
|
||||
|
||||
assert foo() == (1, 2, 3)
|
||||
assert foo() == [1, 2, 3]
|
||||
|
||||
|
||||
def test_fallback_dict_zip_iter_assign():
|
||||
|
|
|
@ -137,7 +137,7 @@ def test_fallback_reversed():
|
|||
def foo():
|
||||
x = reversed([1, 2, 3])
|
||||
return list(x)
|
||||
assert foo() == (3, 2, 1)
|
||||
assert foo() == [3, 2, 1]
|
||||
|
||||
|
||||
def test_fallback_set():
|
||||
|
|
|
@ -35,8 +35,8 @@ def test_fallback_list_with_input_tuple():
|
|||
x.append(4)
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert operator.eq(out, (1, 2, 3, 4))
|
||||
assert isinstance(out, list)
|
||||
assert operator.eq(out, [1, 2, 3, 4])
|
||||
|
||||
|
||||
def test_fallback_list_with_input_list():
|
||||
|
@ -51,8 +51,8 @@ def test_fallback_list_with_input_list():
|
|||
x.append(4)
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert operator.eq(out, (1, 2, 3, 4))
|
||||
assert isinstance(out, list)
|
||||
assert operator.eq(out, [1, 2, 3, 4])
|
||||
|
||||
|
||||
def test_fallback_list_with_input_dict():
|
||||
|
@ -67,8 +67,8 @@ def test_fallback_list_with_input_dict():
|
|||
x.append('d')
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert operator.eq(out, ('a', 'b', 'c', 'd'))
|
||||
assert isinstance(out, list)
|
||||
assert operator.eq(out, ['a', 'b', 'c', 'd'])
|
||||
|
||||
|
||||
def test_fallback_list_with_input_numpy_array():
|
||||
|
|
|
@ -131,7 +131,7 @@ def test_fallback_max_with_two_inputs_list():
|
|||
x = max([1, 2, 3], [4, 5])
|
||||
return x
|
||||
out = foo()
|
||||
assert operator.eq(out, (4, 5))
|
||||
assert operator.eq(out, [4, 5])
|
||||
|
||||
|
||||
def test_fallback_min_with_two_inputs_list():
|
||||
|
@ -145,7 +145,7 @@ def test_fallback_min_with_two_inputs_list():
|
|||
x = min([1, 2, 3], [4, 5])
|
||||
return x
|
||||
out = foo()
|
||||
assert operator.eq(out, (1, 2, 3))
|
||||
assert operator.eq(out, [1, 2, 3])
|
||||
|
||||
|
||||
def test_builtin_function_max_min_with_string():
|
||||
|
|
|
@ -53,21 +53,21 @@ def get_list_comp_5():
|
|||
@jit
|
||||
def get_generator_exp_1():
|
||||
t = (x for x in range(1, 6))
|
||||
return t
|
||||
return tuple(t)
|
||||
|
||||
|
||||
@jit
|
||||
def get_generator_exp_2():
|
||||
t = (x * x for x in range(1, 11) if x > 5 if x % 2 == 0)
|
||||
return t
|
||||
return tuple(t)
|
||||
|
||||
|
||||
def test_list_comp():
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
assert get_list_comp_1() == (1, 2, 3, 4, 5)
|
||||
assert get_list_comp_2() == (1, 4, 9, 16, 25)
|
||||
assert get_list_comp_3() == (4, 16, 36, 64, 100)
|
||||
assert get_list_comp_4() == (36, 64, 100)
|
||||
assert get_list_comp_1() == [1, 2, 3, 4, 5]
|
||||
assert get_list_comp_2() == [1, 4, 9, 16, 25]
|
||||
assert get_list_comp_3() == [4, 16, 36, 64, 100]
|
||||
assert get_list_comp_4() == [36, 64, 100]
|
||||
with pytest.raises(TypeError) as ex:
|
||||
get_list_comp_5()
|
||||
assert "The 'generators' supports 1 'comprehension' in ListComp/GeneratorExp" in str(ex.value)
|
||||
|
|
|
@ -41,4 +41,4 @@ class NetWork(Cell):
|
|||
list1 = [1, 2, 3]
|
||||
net1 = NetWork()
|
||||
result = net1(list1)
|
||||
assert result == (1, 3)
|
||||
assert result == [1, 3]
|
||||
|
|
|
@ -58,7 +58,7 @@ def test_hypermap_value():
|
|||
return map(lambda a: a + 1, self._list)
|
||||
|
||||
net = Net()
|
||||
assert net() == (23, 67, 89, 112)
|
||||
assert net() == [23, 67, 89, 112]
|
||||
|
||||
|
||||
def test_hypermap_func_const():
|
||||
|
@ -80,4 +80,4 @@ def test_hypermap_func_const():
|
|||
return map(lambda f: f(4), _list)
|
||||
|
||||
net = NetMap()
|
||||
assert net() == (8, 12, 16)
|
||||
assert net() == [8, 12, 16]
|
||||
|
|
|
@ -66,7 +66,7 @@ def test_output_const_list():
|
|||
return ret
|
||||
|
||||
net = Net()
|
||||
assert net() == (1, 2, 3)
|
||||
assert net() == [1, 2, 3]
|
||||
|
||||
|
||||
def test_output_const_int():
|
||||
|
|
Loading…
Reference in New Issue