forked from mindspore-Ecosystem/mindspore
Support list tuple in graph
This commit is contained in:
parent
03eff82849
commit
f75ff85f91
|
@ -128,6 +128,7 @@
|
|||
"mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error"
|
||||
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
|
||||
"mindspore/tests/ut/python/rewrite/test_for.py" "protected-access"
|
||||
"mindspore/tests/ut/python/fallback/python_builtin/test_graph_fallback_list_tuple.py" "len-as-condition"
|
||||
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
|
||||
"mindspore/tests/ut/python/rewrite/test_lenet.py" "protected-access"
|
||||
"mindspore/tests/ut/python/rewrite/test_if.py" "protected-access"
|
||||
|
|
|
@ -80,6 +80,7 @@ const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
|
|||
const char PYTHON_MOD_GET_SCRIPT_IDS[] = "get_script_ids";
|
||||
const char PYTHON_MOD_PYTHON_ISINSTANCE[] = "python_isinstance";
|
||||
const char PYTHON_MOD_MS_ISINSTANCE[] = "ms_isinstance";
|
||||
const char PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION[] = "convert_class_to_function";
|
||||
|
||||
const char PYTHON_PARSE_GET_ARGS[] = "get_args";
|
||||
const char PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES[] = "get_args_default_values";
|
||||
|
|
|
@ -148,6 +148,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"__len__", prim::kPrimDictLen}, // P.dict_len
|
||||
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
|
||||
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
|
||||
{"__ms_iter__", prim::kPrimDictGetKeys}, // P.dict_getkeys,
|
||||
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
|
||||
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
|
||||
{"items", prim::kPrimDictItems}, // P.dict_items
|
||||
|
|
|
@ -306,6 +306,36 @@ void CheckInterpretedObject(const AbstractBasePtr &abs) {
|
|||
}
|
||||
}
|
||||
|
||||
EvalResultPtr ConvertClassToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf) {
|
||||
auto val = abs->BuildValue();
|
||||
auto class_val = dyn_cast_ptr<parse::ClassType>(val);
|
||||
const auto &class_name = class_val->name();
|
||||
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
|
||||
auto py_fn = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION, py::str(class_name));
|
||||
if (py::isinstance<py::none>(py_fn)) {
|
||||
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << abs->ToString() << ".";
|
||||
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
|
||||
MS_EXCEPTION(ValueError) << "This may be not defined, or it can't be a operator. Please check code.";
|
||||
}
|
||||
auto list_func_fg = parse::ParsePythonCode(py_fn);
|
||||
auto fg = cnode->func_graph();
|
||||
list_func_fg->set_manager(fg->manager());
|
||||
|
||||
auto &inputs = cnode->inputs();
|
||||
std::vector<AnfNodePtr> new_cnode_inputs;
|
||||
(void)new_cnode_inputs.emplace_back(NewValueNode(list_func_fg));
|
||||
for (std::size_t i = 1; i < inputs.size(); ++i) {
|
||||
(void)new_cnode_inputs.emplace_back(inputs[i]);
|
||||
}
|
||||
auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs);
|
||||
fg->ReplaceInOrder(cnode, new_cnode);
|
||||
|
||||
AnalysisEnginePtr eng = conf->engine();
|
||||
MS_EXCEPTION_IF_NULL(eng);
|
||||
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
|
||||
return eng->ForwardConfig(conf, fn_conf);
|
||||
}
|
||||
|
||||
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
@ -315,6 +345,15 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
return std::make_shared<EvalResult>(possible_func->Clone(), std::make_shared<AttrValueMap>());
|
||||
}
|
||||
|
||||
if (possible_func->isa<AbstractScalar>()) {
|
||||
// Convert class to function, such as list(xxx).
|
||||
auto val = possible_func->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
if (val->isa<parse::ClassType>()) {
|
||||
return ConvertClassToFunc(cnode, possible_func, conf);
|
||||
}
|
||||
}
|
||||
|
||||
auto func = dyn_cast_ptr<AbstractFunction>(possible_func);
|
||||
if (func == nullptr) {
|
||||
CheckInterpretedObject(possible_func);
|
||||
|
@ -325,8 +364,8 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
|
||||
bool contains_isolated_side_effect = false;
|
||||
ConfigPtrList args_conf_list;
|
||||
// Ignore the first node which is function name
|
||||
auto &inputs = cnode->inputs();
|
||||
// Ignore the first node which is function name
|
||||
for (std::size_t i = 1; i < inputs.size(); i++) {
|
||||
const AnfNodePtr &node = inputs[i];
|
||||
args_conf_list.push_back(MakeConfig(node, conf->context(), conf->func_graph()));
|
||||
|
|
|
@ -25,7 +25,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
|
|||
eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
|
||||
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
|
||||
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
|
||||
convert_to_ms_cootensor)
|
||||
convert_to_ms_cootensor, convert_class_to_function)
|
||||
|
||||
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
|
||||
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
|
||||
|
@ -35,4 +35,4 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
|
|||
'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
|
||||
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
|
||||
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor',
|
||||
'convert_to_ms_cootensor']
|
||||
'convert_to_ms_cootensor', 'convert_class_to_function']
|
||||
|
|
|
@ -40,7 +40,8 @@ from mindspore.common.api import _MindsporeFunctionExecutor, _convert_python_dat
|
|||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common.parameter import Parameter
|
||||
from .namespace import Namespace, CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
|
||||
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
|
||||
from .resources import parse_object_map, ops_symbol_map, convert_object_map, convert_class_to_function_map, trope_ns
|
||||
from .resources import SYMBOL_UNDEFINE, NO_IMPLEMENT
|
||||
from .jit_fallback_modules import jit_fallback_third_party_modules_whitelist
|
||||
|
||||
# Define return value
|
||||
|
@ -99,7 +100,7 @@ _builtin_function_or_method_type = type(abs)
|
|||
|
||||
# Unsupported python builtin type in graph mode.
|
||||
_unsupported_python_builtin_type = (
|
||||
list, tuple, set, dict, slice, bool, int, float, str, complex, reversed, type,
|
||||
set, dict, slice, bool, int, float, str, complex, reversed, type,
|
||||
)
|
||||
|
||||
_unsupported_internal_type = (
|
||||
|
@ -107,7 +108,7 @@ _unsupported_internal_type = (
|
|||
)
|
||||
|
||||
_hybrid_type = (
|
||||
print, enumerate, zip, map, filter, abs, all, any, round, max, min, hasattr
|
||||
print, enumerate, zip, map, filter, abs, all, any, round, max, min, hasattr, list, tuple
|
||||
)
|
||||
|
||||
# Unsupported python builtin type in JIT Fallback.
|
||||
|
@ -425,6 +426,11 @@ def create_instance(cls_type, params=None):
|
|||
return obj
|
||||
|
||||
|
||||
def convert_class_to_function(cls_str):
|
||||
"""Convert class to function."""
|
||||
return convert_class_to_function_map.get(cls_str)
|
||||
|
||||
|
||||
def python_isinstance(x, cmp_type):
|
||||
"""Python isinstance function."""
|
||||
# Convert _c_expression tensor to python tensor.
|
||||
|
|
|
@ -175,3 +175,9 @@ convert_object_map = {
|
|||
|
||||
if not security.enable_security():
|
||||
convert_object_map[T.print] = F.print_
|
||||
|
||||
# Convert class object to callable function
|
||||
convert_class_to_function_map = {
|
||||
"class 'list'": M.list_func,
|
||||
"class 'tuple'": M.tuple_func
|
||||
}
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
"""standard_method"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from mindspore import Tensor, CSRTensor, COOTensor, ms_class
|
||||
from mindspore import Tensor, CSRTensor, COOTensor, RowTensor, ms_class
|
||||
from mindspore import dtype as mstype
|
||||
|
||||
from ..._checkparam import Validator as validator
|
||||
|
@ -1860,6 +1860,48 @@ def ms_round(*data):
|
|||
return constant_round(*data)
|
||||
|
||||
|
||||
def list_func(*data):
|
||||
"""Implementation of `list`."""
|
||||
data_len = len(data)
|
||||
if data_len >= 2:
|
||||
const_utils.raise_type_error("list() requires 0 or 1 arguments.")
|
||||
if data_len == 0:
|
||||
return F.make_list()
|
||||
data = data[0]
|
||||
if isinstance(data, (CSRTensor, COOTensor, RowTensor)):
|
||||
const_utils.raise_type_error("list() does not support single sparse tensor input.")
|
||||
if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"):
|
||||
data_type = F.typeof(data)
|
||||
const_utils.raise_type_error(str(data_type) + " object is not iterable.")
|
||||
if isinstance(data, dict):
|
||||
data = data.keys()
|
||||
ret = F.make_list()
|
||||
for i in range(len(data)):
|
||||
ret = ret + F.make_list(data[i])
|
||||
return ret
|
||||
|
||||
|
||||
def tuple_func(*data):
|
||||
"""Implementation of `tuple`."""
|
||||
data_len = len(data)
|
||||
if data_len >= 2:
|
||||
const_utils.raise_type_error("tuple() requires 0 or 1 arguments.")
|
||||
if data_len == 0:
|
||||
return F.make_tuple()
|
||||
data = data[0]
|
||||
if isinstance(data, (CSRTensor, COOTensor, RowTensor)):
|
||||
const_utils.raise_type_error("tuple() does not support single sparse tensor input.")
|
||||
if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"):
|
||||
data_type = F.typeof(data)
|
||||
const_utils.raise_type_error(str(data_type) + " object is not iterable.")
|
||||
if isinstance(data, dict):
|
||||
data = data.keys()
|
||||
ret = F.make_tuple()
|
||||
for i in range(len(data)):
|
||||
ret = ret + F.make_tuple(data[i])
|
||||
return ret
|
||||
|
||||
|
||||
def max_tensor(*data):
|
||||
"""Get the max of tensor inputs."""
|
||||
max_tensor_data = data[0]
|
||||
|
|
|
@ -29,7 +29,7 @@ from operator import ( # noqa
|
|||
# support system function call
|
||||
from builtins import ( # noqa
|
||||
bool, getattr, setattr, hasattr, len, iter, next, pow, range, map, zip,
|
||||
print, enumerate, isinstance, filter, abs, all, any, round, max, min
|
||||
print, enumerate, isinstance, filter, abs, all, any, round, max, min, list, tuple
|
||||
)
|
||||
|
||||
# support functools
|
||||
|
@ -47,7 +47,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
|
|||
'matmul', 'getitem', 'setitem',
|
||||
'bool', 'getattr', 'setattr', 'hasattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
|
||||
'partial', 'print', 'enumerate', 'isinstance', 'filter', 'abs', 'all', 'any', 'round',
|
||||
'exp', 'log', 'sin', 'cos', 'tan', 'max', 'min']
|
||||
'exp', 'log', 'sin', 'cos', 'tan', 'max', 'min', 'list', 'tuple']
|
||||
|
||||
|
||||
def MakeTuple(*elts): # pragma: no cover
|
||||
|
|
|
@ -661,11 +661,10 @@ def csr_add(a, b, alpha, beta):
|
|||
b_batch_pointers = make_tensor([0, b.values.shape[0]], mstype.int32)
|
||||
a_shape = make_tensor(a.shape, mstype.int32)
|
||||
b_shape = make_tensor(b.shape, mstype.int32)
|
||||
shape, _, indptr, indices, values = csr_add_op(a_shape, a_batch_pointers, a.indptr, a.indices, a.values,
|
||||
b_shape, b_batch_pointers, b.indptr, b.indices, b.values,
|
||||
alpha, beta)
|
||||
output_shape = tuple(shape.asnumpy().tolist())
|
||||
return CSRTensor(indptr, indices, values, output_shape)
|
||||
_, _, indptr, indices, values = csr_add_op(a_shape, a_batch_pointers, a.indptr, a.indices, a.values,
|
||||
b_shape, b_batch_pointers, b.indptr, b.indices, b.values,
|
||||
alpha, beta)
|
||||
return CSRTensor(indptr, indices, values, a.shape)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
|
|
@ -238,3 +238,151 @@ def test_builtin_function_max_min_with_tensor_list():
|
|||
min_out, max_out = foo(Tensor([1, 2, 3, 4, 5], dtype=mstype.float32))
|
||||
assert operator.eq(min_out, 1)
|
||||
assert operator.eq(max_out, 5)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_list_with_input_constant_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test list() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = list(Tensor([1, 2, 3]))
|
||||
x.append(Tensor([4]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 4
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert out[0].asnumpy() == 1
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert out[1].asnumpy() == 2
|
||||
assert isinstance(out[2], Tensor)
|
||||
assert out[2].asnumpy() == 3
|
||||
assert isinstance(out[3], Tensor)
|
||||
assert out[3].asnumpy() == 4
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_list_with_input_constant_tensor_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test list() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = list(Tensor([[1, 2], [3, 4]]))
|
||||
x.append(Tensor([5, 6]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 3
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert np.allclose(out[1].asnumpy(), np.array([3, 4]))
|
||||
assert isinstance(out[2], Tensor)
|
||||
assert np.allclose(out[2].asnumpy(), np.array([5, 6]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_builtin_function_list_with_non_constant_tensor():
|
||||
"""
|
||||
Feature: Graph list function.
|
||||
Description: When the input to list() is non constant tensor, list function will return correct result.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo(x):
|
||||
return list(x)
|
||||
|
||||
ret = foo(Tensor([[1, 2, 3], [4, 5, 6]]))
|
||||
assert len(ret) == 2
|
||||
assert np.all(ret[0].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(ret[1].asnumpy() == np.array([4, 5, 6]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_tuple_with_input_constant_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test tuple() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = tuple(Tensor([1, 2, 3]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 3
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert out[0].asnumpy() == 1
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert out[1].asnumpy() == 2
|
||||
assert isinstance(out[2], Tensor)
|
||||
assert out[2].asnumpy() == 3
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_fallback_tuple_with_input_constant_tensor_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test tuple() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = list(Tensor([[1, 2], [3, 4]]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 2
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert np.allclose(out[1].asnumpy(), np.array([3, 4]))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_builtin_function_tuple_with_non_constant_tensor():
|
||||
"""
|
||||
Feature: Graph tuple function.
|
||||
Description: When the input to tuple() is non constant tensor, list function will return correct result.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo(x):
|
||||
return tuple(x)
|
||||
|
||||
ret = foo(Tensor([[1, 2, 3], [4, 5, 6]]))
|
||||
assert len(ret) == 2
|
||||
assert np.all(ret[0].asnumpy() == np.array([1, 2, 3]))
|
||||
assert np.all(ret[1].asnumpy() == np.array([4, 5, 6]))
|
||||
|
|
|
@ -83,6 +83,24 @@ def test_fallback_list_with_input_numpy_array():
|
|||
assert np.allclose(np.array([1, 2, 3, 4]), out.asnumpy())
|
||||
|
||||
|
||||
def test_fallback_list_with_empty_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test list() in graph mode with empty input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = list()
|
||||
if isinstance(x, list):
|
||||
if len(x) == 0:
|
||||
return 1
|
||||
return 2
|
||||
return 3
|
||||
out = foo()
|
||||
assert out == 1
|
||||
|
||||
|
||||
def test_fallback_list_with_input_number():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -98,52 +116,6 @@ def test_fallback_list_with_input_number():
|
|||
assert "object is not iterable" in str(ex.value)
|
||||
|
||||
|
||||
def test_fallback_list_with_input_constant_tensor():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test list() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = list(Tensor([1, 2, 3]))
|
||||
x.append(Tensor([4]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 4
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert out[0].asnumpy() == 1
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert out[1].asnumpy() == 2
|
||||
assert isinstance(out[2], Tensor)
|
||||
assert out[2].asnumpy() == 3
|
||||
assert isinstance(out[3], Tensor)
|
||||
assert out[3].asnumpy() == 4
|
||||
|
||||
|
||||
def test_fallback_list_with_input_constant_tensor_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test list() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = list(Tensor([[1, 2], [3, 4]]))
|
||||
x.append(Tensor([5, 6]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 3
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert np.allclose(out[1].asnumpy(), np.array([3, 4]))
|
||||
assert isinstance(out[2], Tensor)
|
||||
assert np.allclose(out[2].asnumpy(), np.array([5, 6]))
|
||||
|
||||
|
||||
def test_fallback_tuple_with_input_list():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
|
@ -203,44 +175,22 @@ def test_fallback_tuple_with_input_numpy_array():
|
|||
assert np.allclose(np.array([1, 2, 3]), out.asnumpy())
|
||||
|
||||
|
||||
def test_fallback_tuple_with_input_constant_tensor():
|
||||
def test_fallback_tuple_with_empty_input():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test tuple() in graph mode with constant tensor.
|
||||
Description: Test tuple() in graph mode with empty input.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = tuple(Tensor([1, 2, 3]))
|
||||
return x
|
||||
x = tuple()
|
||||
if isinstance(x, tuple):
|
||||
if len(x) == 0:
|
||||
return 1
|
||||
return 2
|
||||
return 3
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 3
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert out[0].asnumpy() == 1
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert out[1].asnumpy() == 2
|
||||
assert isinstance(out[2], Tensor)
|
||||
assert out[2].asnumpy() == 3
|
||||
|
||||
|
||||
def test_fallback_tuple_with_input_constant_tensor_2():
|
||||
"""
|
||||
Feature: JIT Fallback
|
||||
Description: Test tuple() in graph mode with constant tensor.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms_function
|
||||
def foo():
|
||||
x = list(Tensor([[1, 2], [3, 4]]))
|
||||
return x
|
||||
out = foo()
|
||||
assert isinstance(out, tuple)
|
||||
assert len(out) == 2
|
||||
assert isinstance(out[0], Tensor)
|
||||
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
|
||||
assert isinstance(out[1], Tensor)
|
||||
assert np.allclose(out[1].asnumpy(), np.array([3, 4]))
|
||||
assert out == 1
|
||||
|
||||
|
||||
def test_fallback_tuple_with_input_number():
|
||||
|
|
Loading…
Reference in New Issue