Support list tuple in graph

This commit is contained in:
liangzhibo 2022-08-16 16:40:23 +08:00
parent 03eff82849
commit f75ff85f91
12 changed files with 284 additions and 91 deletions

View File

@ -128,6 +128,7 @@
"mindspore/tests/ut/python/rewrite/test_node.py" "syntax-error"
"mindspore/tests/ut/python/rewrite/test_node.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_for.py" "protected-access"
"mindspore/tests/ut/python/fallback/python_builtin/test_graph_fallback_list_tuple.py" "len-as-condition"
"mindspore/tests/ut/python/rewrite/test_symbol_tree.py" "len-as-condition"
"mindspore/tests/ut/python/rewrite/test_lenet.py" "protected-access"
"mindspore/tests/ut/python/rewrite/test_if.py" "protected-access"

View File

@ -80,6 +80,7 @@ const char PYTHON_MOD_EVAL_PY_SCRIPT[] = "eval_script";
const char PYTHON_MOD_GET_SCRIPT_IDS[] = "get_script_ids";
const char PYTHON_MOD_PYTHON_ISINSTANCE[] = "python_isinstance";
const char PYTHON_MOD_MS_ISINSTANCE[] = "ms_isinstance";
const char PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION[] = "convert_class_to_function";
const char PYTHON_PARSE_GET_ARGS[] = "get_args";
const char PYTHON_PARSE_GET_ARGS_DEFAULT_VALUES[] = "get_args_default_values";

View File

@ -148,6 +148,7 @@ BuiltInTypeMap &GetMethodMap() {
{"__len__", prim::kPrimDictLen}, // P.dict_len
{"__getitem__", prim::kPrimDictGetItem}, // P.dict_getitem
{"__setitem__", prim::kPrimDictSetItem}, // P.dict_setitem,
{"__ms_iter__", prim::kPrimDictGetKeys}, // P.dict_getkeys,
{"keys", prim::kPrimDictGetKeys}, // P.dict_getkeys,
{"values", prim::kPrimDictGetValues}, // P.dict_getvalues,
{"items", prim::kPrimDictItems}, // P.dict_items

View File

@ -306,6 +306,36 @@ void CheckInterpretedObject(const AbstractBasePtr &abs) {
}
}
EvalResultPtr ConvertClassToFunc(const CNodePtr &cnode, const AbstractBasePtr &abs, const AnfNodeConfigPtr &conf) {
auto val = abs->BuildValue();
auto class_val = dyn_cast_ptr<parse::ClassType>(val);
const auto &class_name = class_val->name();
py::module mod = python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
auto py_fn = python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_CONVERT_CLASS_TO_FUNCTION, py::str(class_name));
if (py::isinstance<py::none>(py_fn)) {
MS_LOG(ERROR) << "Can not cast to a AbstractFunction from " << abs->ToString() << ".";
MS_LOG(ERROR) << "It's called at: " << cnode->DebugString();
MS_EXCEPTION(ValueError) << "This may be not defined, or it can't be a operator. Please check code.";
}
auto list_func_fg = parse::ParsePythonCode(py_fn);
auto fg = cnode->func_graph();
list_func_fg->set_manager(fg->manager());
auto &inputs = cnode->inputs();
std::vector<AnfNodePtr> new_cnode_inputs;
(void)new_cnode_inputs.emplace_back(NewValueNode(list_func_fg));
for (std::size_t i = 1; i < inputs.size(); ++i) {
(void)new_cnode_inputs.emplace_back(inputs[i]);
}
auto new_cnode = fg->NewCNodeInOrder(new_cnode_inputs);
fg->ReplaceInOrder(cnode, new_cnode);
AnalysisEnginePtr eng = conf->engine();
MS_EXCEPTION_IF_NULL(eng);
AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_cnode, conf->context(), conf->func_graph());
return eng->ForwardConfig(conf, fn_conf);
}
EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) {
MS_EXCEPTION_IF_NULL(conf);
MS_EXCEPTION_IF_NULL(cnode);
@ -315,6 +345,15 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
return std::make_shared<EvalResult>(possible_func->Clone(), std::make_shared<AttrValueMap>());
}
if (possible_func->isa<AbstractScalar>()) {
// Convert class to function, such as list(xxx).
auto val = possible_func->BuildValue();
MS_EXCEPTION_IF_NULL(val);
if (val->isa<parse::ClassType>()) {
return ConvertClassToFunc(cnode, possible_func, conf);
}
}
auto func = dyn_cast_ptr<AbstractFunction>(possible_func);
if (func == nullptr) {
CheckInterpretedObject(possible_func);
@ -325,8 +364,8 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
bool contains_isolated_side_effect = false;
ConfigPtrList args_conf_list;
// Ignore the first node which is function name
auto &inputs = cnode->inputs();
// Ignore the first node which is function name
for (std::size_t i = 1; i < inputs.size(); i++) {
const AnfNodePtr &node = inputs[i];
args_conf_list.push_back(MakeConfig(node, conf->context(), conf->func_graph()));

View File

@ -25,7 +25,7 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type,
eval_script, get_script_ids, expand_expr_statement, is_class_member, parse_cb, resolve_symbol,
convert_to_ms_tensor, get_object_description, get_class_attr_namespace_symbol, get_ms_class_name,
is_class_type, check_obj_bool, python_isinstance, ms_isinstance, convert_to_ms_csrtensor,
convert_to_ms_cootensor)
convert_to_ms_cootensor, convert_class_to_function)
__all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'generate_scope',
'get_bprop_method_of_class', 'get_class_instance_type', 'get_class_member_namespace_symbol',
@ -35,4 +35,4 @@ __all__ = ['Parser', 'create_instance', 'is_supported_create_instance_type', 'ge
'eval_script', 'get_script_ids', 'expand_expr_statement', 'is_class_member', 'parse_cb', 'resolve_symbol',
'convert_to_ms_tensor', 'get_object_description', 'get_class_attr_namespace_symbol', 'get_ms_class_name',
'is_class_type', 'check_obj_bool', 'python_isinstance', 'ms_isinstance', 'convert_to_ms_csrtensor',
'convert_to_ms_cootensor']
'convert_to_ms_cootensor', 'convert_class_to_function']

View File

@ -40,7 +40,8 @@ from mindspore.common.api import _MindsporeFunctionExecutor, _convert_python_dat
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from .namespace import Namespace, CellNamespace, ClosureNamespace, ClassMemberNamespace, ClassAttrNamespace
from .resources import parse_object_map, ops_symbol_map, convert_object_map, trope_ns, SYMBOL_UNDEFINE, NO_IMPLEMENT
from .resources import parse_object_map, ops_symbol_map, convert_object_map, convert_class_to_function_map, trope_ns
from .resources import SYMBOL_UNDEFINE, NO_IMPLEMENT
from .jit_fallback_modules import jit_fallback_third_party_modules_whitelist
# Define return value
@ -99,7 +100,7 @@ _builtin_function_or_method_type = type(abs)
# Unsupported python builtin type in graph mode.
_unsupported_python_builtin_type = (
list, tuple, set, dict, slice, bool, int, float, str, complex, reversed, type,
set, dict, slice, bool, int, float, str, complex, reversed, type,
)
_unsupported_internal_type = (
@ -107,7 +108,7 @@ _unsupported_internal_type = (
)
_hybrid_type = (
print, enumerate, zip, map, filter, abs, all, any, round, max, min, hasattr
print, enumerate, zip, map, filter, abs, all, any, round, max, min, hasattr, list, tuple
)
# Unsupported python builtin type in JIT Fallback.
@ -425,6 +426,11 @@ def create_instance(cls_type, params=None):
return obj
def convert_class_to_function(cls_str):
"""Convert class to function."""
return convert_class_to_function_map.get(cls_str)
def python_isinstance(x, cmp_type):
"""Python isinstance function."""
# Convert _c_expression tensor to python tensor.

View File

@ -175,3 +175,9 @@ convert_object_map = {
if not security.enable_security():
convert_object_map[T.print] = F.print_
# Convert class object to callable function
convert_class_to_function_map = {
"class 'list'": M.list_func,
"class 'tuple'": M.tuple_func
}

View File

@ -17,7 +17,7 @@
"""standard_method"""
from __future__ import absolute_import
from mindspore import Tensor, CSRTensor, COOTensor, ms_class
from mindspore import Tensor, CSRTensor, COOTensor, RowTensor, ms_class
from mindspore import dtype as mstype
from ..._checkparam import Validator as validator
@ -1860,6 +1860,48 @@ def ms_round(*data):
return constant_round(*data)
def list_func(*data):
"""Implementation of `list`."""
data_len = len(data)
if data_len >= 2:
const_utils.raise_type_error("list() requires 0 or 1 arguments.")
if data_len == 0:
return F.make_list()
data = data[0]
if isinstance(data, (CSRTensor, COOTensor, RowTensor)):
const_utils.raise_type_error("list() does not support single sparse tensor input.")
if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"):
data_type = F.typeof(data)
const_utils.raise_type_error(str(data_type) + " object is not iterable.")
if isinstance(data, dict):
data = data.keys()
ret = F.make_list()
for i in range(len(data)):
ret = ret + F.make_list(data[i])
return ret
def tuple_func(*data):
"""Implementation of `tuple`."""
data_len = len(data)
if data_len >= 2:
const_utils.raise_type_error("tuple() requires 0 or 1 arguments.")
if data_len == 0:
return F.make_tuple()
data = data[0]
if isinstance(data, (CSRTensor, COOTensor, RowTensor)):
const_utils.raise_type_error("tuple() does not support single sparse tensor input.")
if not isinstance(data, Tensor) and not hasattr(data, "__ms_iter__"):
data_type = F.typeof(data)
const_utils.raise_type_error(str(data_type) + " object is not iterable.")
if isinstance(data, dict):
data = data.keys()
ret = F.make_tuple()
for i in range(len(data)):
ret = ret + F.make_tuple(data[i])
return ret
def max_tensor(*data):
"""Get the max of tensor inputs."""
max_tensor_data = data[0]

View File

@ -29,7 +29,7 @@ from operator import ( # noqa
# support system function call
from builtins import ( # noqa
bool, getattr, setattr, hasattr, len, iter, next, pow, range, map, zip,
print, enumerate, isinstance, filter, abs, all, any, round, max, min
print, enumerate, isinstance, filter, abs, all, any, round, max, min, list, tuple
)
# support functools
@ -47,7 +47,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt',
'matmul', 'getitem', 'setitem',
'bool', 'getattr', 'setattr', 'hasattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip',
'partial', 'print', 'enumerate', 'isinstance', 'filter', 'abs', 'all', 'any', 'round',
'exp', 'log', 'sin', 'cos', 'tan', 'max', 'min']
'exp', 'log', 'sin', 'cos', 'tan', 'max', 'min', 'list', 'tuple']
def MakeTuple(*elts): # pragma: no cover

View File

@ -661,11 +661,10 @@ def csr_add(a, b, alpha, beta):
b_batch_pointers = make_tensor([0, b.values.shape[0]], mstype.int32)
a_shape = make_tensor(a.shape, mstype.int32)
b_shape = make_tensor(b.shape, mstype.int32)
shape, _, indptr, indices, values = csr_add_op(a_shape, a_batch_pointers, a.indptr, a.indices, a.values,
b_shape, b_batch_pointers, b.indptr, b.indices, b.values,
alpha, beta)
output_shape = tuple(shape.asnumpy().tolist())
return CSRTensor(indptr, indices, values, output_shape)
_, _, indptr, indices, values = csr_add_op(a_shape, a_batch_pointers, a.indptr, a.indices, a.values,
b_shape, b_batch_pointers, b.indptr, b.indices, b.values,
alpha, beta)
return CSRTensor(indptr, indices, values, a.shape)
__all__ = [

View File

@ -238,3 +238,151 @@ def test_builtin_function_max_min_with_tensor_list():
min_out, max_out = foo(Tensor([1, 2, 3, 4, 5], dtype=mstype.float32))
assert operator.eq(min_out, 1)
assert operator.eq(max_out, 5)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_list_with_input_constant_tensor():
"""
Feature: JIT Fallback
Description: Test list() in graph mode with constant tensor.
Expectation: No exception.
"""
@ms_function
def foo():
x = list(Tensor([1, 2, 3]))
x.append(Tensor([4]))
return x
out = foo()
assert isinstance(out, tuple)
assert len(out) == 4
assert isinstance(out[0], Tensor)
assert out[0].asnumpy() == 1
assert isinstance(out[1], Tensor)
assert out[1].asnumpy() == 2
assert isinstance(out[2], Tensor)
assert out[2].asnumpy() == 3
assert isinstance(out[3], Tensor)
assert out[3].asnumpy() == 4
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_list_with_input_constant_tensor_2():
"""
Feature: JIT Fallback
Description: Test list() in graph mode with constant tensor.
Expectation: No exception.
"""
@ms_function
def foo():
x = list(Tensor([[1, 2], [3, 4]]))
x.append(Tensor([5, 6]))
return x
out = foo()
assert isinstance(out, tuple)
assert len(out) == 3
assert isinstance(out[0], Tensor)
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
assert isinstance(out[1], Tensor)
assert np.allclose(out[1].asnumpy(), np.array([3, 4]))
assert isinstance(out[2], Tensor)
assert np.allclose(out[2].asnumpy(), np.array([5, 6]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_builtin_function_list_with_non_constant_tensor():
"""
Feature: Graph list function.
Description: When the input to list() is non constant tensor, list function will return correct result.
Expectation: No exception.
"""
@ms_function
def foo(x):
return list(x)
ret = foo(Tensor([[1, 2, 3], [4, 5, 6]]))
assert len(ret) == 2
assert np.all(ret[0].asnumpy() == np.array([1, 2, 3]))
assert np.all(ret[1].asnumpy() == np.array([4, 5, 6]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_tuple_with_input_constant_tensor():
"""
Feature: JIT Fallback
Description: Test tuple() in graph mode with constant tensor.
Expectation: No exception.
"""
@ms_function
def foo():
x = tuple(Tensor([1, 2, 3]))
return x
out = foo()
assert isinstance(out, tuple)
assert len(out) == 3
assert isinstance(out[0], Tensor)
assert out[0].asnumpy() == 1
assert isinstance(out[1], Tensor)
assert out[1].asnumpy() == 2
assert isinstance(out[2], Tensor)
assert out[2].asnumpy() == 3
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_fallback_tuple_with_input_constant_tensor_2():
"""
Feature: JIT Fallback
Description: Test tuple() in graph mode with constant tensor.
Expectation: No exception.
"""
@ms_function
def foo():
x = list(Tensor([[1, 2], [3, 4]]))
return x
out = foo()
assert isinstance(out, tuple)
assert len(out) == 2
assert isinstance(out[0], Tensor)
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
assert isinstance(out[1], Tensor)
assert np.allclose(out[1].asnumpy(), np.array([3, 4]))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_builtin_function_tuple_with_non_constant_tensor():
"""
Feature: Graph tuple function.
Description: When the input to tuple() is non constant tensor, list function will return correct result.
Expectation: No exception.
"""
@ms_function
def foo(x):
return tuple(x)
ret = foo(Tensor([[1, 2, 3], [4, 5, 6]]))
assert len(ret) == 2
assert np.all(ret[0].asnumpy() == np.array([1, 2, 3]))
assert np.all(ret[1].asnumpy() == np.array([4, 5, 6]))

View File

@ -83,6 +83,24 @@ def test_fallback_list_with_input_numpy_array():
assert np.allclose(np.array([1, 2, 3, 4]), out.asnumpy())
def test_fallback_list_with_empty_input():
"""
Feature: JIT Fallback
Description: Test list() in graph mode with empty input.
Expectation: No exception.
"""
@ms_function
def foo():
x = list()
if isinstance(x, list):
if len(x) == 0:
return 1
return 2
return 3
out = foo()
assert out == 1
def test_fallback_list_with_input_number():
"""
Feature: JIT Fallback
@ -98,52 +116,6 @@ def test_fallback_list_with_input_number():
assert "object is not iterable" in str(ex.value)
def test_fallback_list_with_input_constant_tensor():
"""
Feature: JIT Fallback
Description: Test list() in graph mode with constant tensor.
Expectation: No exception.
"""
@ms_function
def foo():
x = list(Tensor([1, 2, 3]))
x.append(Tensor([4]))
return x
out = foo()
assert isinstance(out, tuple)
assert len(out) == 4
assert isinstance(out[0], Tensor)
assert out[0].asnumpy() == 1
assert isinstance(out[1], Tensor)
assert out[1].asnumpy() == 2
assert isinstance(out[2], Tensor)
assert out[2].asnumpy() == 3
assert isinstance(out[3], Tensor)
assert out[3].asnumpy() == 4
def test_fallback_list_with_input_constant_tensor_2():
"""
Feature: JIT Fallback
Description: Test list() in graph mode with constant tensor.
Expectation: No exception.
"""
@ms_function
def foo():
x = list(Tensor([[1, 2], [3, 4]]))
x.append(Tensor([5, 6]))
return x
out = foo()
assert isinstance(out, tuple)
assert len(out) == 3
assert isinstance(out[0], Tensor)
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
assert isinstance(out[1], Tensor)
assert np.allclose(out[1].asnumpy(), np.array([3, 4]))
assert isinstance(out[2], Tensor)
assert np.allclose(out[2].asnumpy(), np.array([5, 6]))
def test_fallback_tuple_with_input_list():
"""
Feature: JIT Fallback
@ -203,44 +175,22 @@ def test_fallback_tuple_with_input_numpy_array():
assert np.allclose(np.array([1, 2, 3]), out.asnumpy())
def test_fallback_tuple_with_input_constant_tensor():
def test_fallback_tuple_with_empty_input():
"""
Feature: JIT Fallback
Description: Test tuple() in graph mode with constant tensor.
Description: Test tuple() in graph mode with empty input.
Expectation: No exception.
"""
@ms_function
def foo():
x = tuple(Tensor([1, 2, 3]))
return x
x = tuple()
if isinstance(x, tuple):
if len(x) == 0:
return 1
return 2
return 3
out = foo()
assert isinstance(out, tuple)
assert len(out) == 3
assert isinstance(out[0], Tensor)
assert out[0].asnumpy() == 1
assert isinstance(out[1], Tensor)
assert out[1].asnumpy() == 2
assert isinstance(out[2], Tensor)
assert out[2].asnumpy() == 3
def test_fallback_tuple_with_input_constant_tensor_2():
"""
Feature: JIT Fallback
Description: Test tuple() in graph mode with constant tensor.
Expectation: No exception.
"""
@ms_function
def foo():
x = list(Tensor([[1, 2], [3, 4]]))
return x
out = foo()
assert isinstance(out, tuple)
assert len(out) == 2
assert isinstance(out[0], Tensor)
assert np.allclose(out[0].asnumpy(), np.array([1, 2]))
assert isinstance(out[1], Tensor)
assert np.allclose(out[1].asnumpy(), np.array([3, 4]))
assert out == 1
def test_fallback_tuple_with_input_number():